diff --git a/.gitattributes b/.gitattributes new file mode 100755 index 0000000000000000000000000000000000000000..f004b59c173a888e045e0f04bbfd45542ac03358 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,61 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +branchsbm.png filter=lfs diff=lfs merge=lfs -text +branchsbm/branchsbm.png filter=lfs diff=lfs merge=lfs -text +branchsbm/clonidine.png filter=lfs diff=lfs merge=lfs -text +branchsbm/lidar.png filter=lfs diff=lfs merge=lfs -text +branchsbm/mouse.png filter=lfs diff=lfs merge=lfs -text +branchsbm/trametinib.png filter=lfs diff=lfs merge=lfs -text +clonidine.png filter=lfs diff=lfs merge=lfs -text +lidar.png filter=lfs diff=lfs merge=lfs -text +mouse.png filter=lfs diff=lfs merge=lfs -text +trametinib.png filter=lfs diff=lfs merge=lfs -text +data/pca_and_leiden_labels.csv filter=lfs diff=lfs merge=lfs -text +data/mouse_hematopoiesis.csv filter=lfs diff=lfs merge=lfs -text +data/simulation_gene.csv filter=lfs diff=lfs merge=lfs -text +data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv filter=lfs diff=lfs merge=lfs -text +data/Veres_alltime.csv filter=lfs diff=lfs merge=lfs -text +data/Weinreb_alltime.csv filter=lfs diff=lfs merge=lfs -text +data/Weinreb_t2_leiden_clusters.csv filter=lfs diff=lfs merge=lfs -text +data/eb_noscale.csv filter=lfs diff=lfs merge=lfs -text +data/emt.csv filter=lfs diff=lfs merge=lfs -text +*.csv filter=lfs diff=lfs merge=lfs -text +data/*.las filter=lfs diff=lfs merge=lfs -text +*.csv filter=lfs diff=lfs merge=lfs -text +data/*.las filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +assets/veres.png filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..70db352dabd424833a75c09be425104dfb0ff1ff --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +logs/ +wandb/ +__pycache__/ +checkpoints/ +lightining_logs/ +results/ +*.log +*.pyc +lightining_logs/ +figures/ +*.ckpt +*.csv +data/ +extra/ +.vscode/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..1d27760b65e295b067ebf9972b320d3c63683d13 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Sophia Tang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100755 index 0000000000000000000000000000000000000000..4201d6f8c3fbdd71edf64171b30f41e821a70224 --- /dev/null +++ b/README.md @@ -0,0 +1,68 @@ +# [Branched Schrödinger Bridge Matching](https://arxiv.org/abs/2506.09007) (ICLR 2026) 🌳🧬 + +[**Sophia Tang**](https://sophtang.github.io/), [**Yinuo Zhang**](https://www.linkedin.com/in/yinuozhang98/), [**Alexander Tong**](https://www.alextong.net/) and [**Pranam Chatterjee**](https://www.chatterjeelab.com/) + +![BranchSBM](assets/branchsbm_anim.gif) + +This is the repository for [**Branched Schrödinger Bridge Matching**](https://arxiv.org/abs/2506.09007) (ICLR 2026) 🌳🧬. It is partially built on the [**Metric Flow Matching repo**](https://github.com/kkapusniak/metric-flow-matching) ([Kapusniak et al., 2024](https://arxiv.org/abs/2405.14780)). + +Predicting how a population evolves between an initial and final state is central to many problems in generative modeling, from simulating perturbation responses to modelling cell fate decisions 🧫. Existing approaches, such as flow matching and Schrödinger Bridge Matching, effectively learn mappings between two distributions by modelling a single stochastic path. However, these methods are **inherently limited to unimodal transitions and cannot capture branched or divergent evolution from a common origin to multiple distinct outcomes.** + +A key challenge in trajectory matching is reconstructing multi-modal marginals, particularly when modes diverge along distinct dynamical paths . Existing Schrödinger bridge and flow matching frameworks approximate multi-modal distributions by simulating many *independent* particle trajectories, which are susceptible to mode collapse, with particles concentrating on dominant high-density modes or traversing only low-energy intermediate paths. + +To address this, we introduce **Branched Schrödinger Bridge Matching (BranchSBM)** 🌳🧬, a novel framework that learns a set of diverging velocity fields to reconstruct multi-modal target distributions while simultaneously learning growth networks that allocate mass across branches. Guided by a time-dependent potential energy function Vt, BranchSBM captures diverging, energy-minimizing dynamics without requiring intermediate-time supervision and can generate the full branched evolution from a single initial sample. + +🌟 We define the **Branched Generalized Schrödinger Bridge problem** and introduce BranchSBM, a novel matching framework that learns optimal branched trajectories from an initial distribution to multiple target distributions. + +🌟 We derive the Branched Conditional Stochastic Optimal Control (CondSOC) problem as the sum of Unbalanced CondSOC objectives and leverage a multi-stage training algorithm to learn the optimal branching drift and growth fields that transport mass along a branched trajectory. + +🌟 We demonstrate the unique capability of BranchSBM to model dynamic branching trajectories across various real-world problems, including 3D navigation over LiDAR manifolds, modelling differentiating single-cell population dynamics, and simulating heterogeneous cellular responses to drug perturbation. + +# Experiments +Code and instructions to reproduce our results are provided in `/scripts/README`. + +## LiDAR Experiment 🗻 + +As a proof of concept, we first evaluate BranchSBM for navigating branched paths along the surface of a three-dimensional LiDAR manifold, from an initial distribution to two distinct target distributions while remaining on low-altitude regions of the manifold. + +![LiDAR Experiment](assets/lidar.png) + +## Mouse Hematopoiesis and Pancreatic β-Cell Experiment 🧫 + +BranchSBM is uniquely positioned to model single-cell population dynamics where a homogeneous cell population (e.g., progenitor cells) differentiates into several distinct subpopulation branches, each of which independently undergoes growth dynamics. In this experiment, we demonstrate this capability on mouse hematopoiesis data and pancreatic β-cell differentiation data. + +We evaluate BranchSBM on a mouse hematopoiesis scRNA-seq dataset containing three developmental time points representing progenitor cells differentiating into two terminal cell fates. Compared to a single-branch SBM, BranchSBM successfully learns distinct branching trajectories and accurately reconstructs intermediate cell states, demonstrating its ability to recover lineage bifurcation dynamics. + +![Mouse Experiment](assets/mouse.png) + +We evaluate BranchSBM on a pancreatic β-cell differentiation dataset ([Veres et al., 2019](https://www.nature.com/articles/s41586-019-1168-5)) containing 51,274 cells collected across eight time points as human pluripotent stem cells differentiate into pancreatic β-like cells. Cells are projected into a 30-dimensional PCA space, and Leiden clustering is used to define 11 terminal cell populations at the final time point. + +BranchSBM is trained using only samples from the initial and final states, while intermediate distributions are inferred by learning trajectories constrained to the data manifold using an RBF state cost. Compared to baselines, BranchSBM significantly improves reconstruction of both intermediate and terminal distributions, achieving lower Wasserstein distances at validation time points. These results demonstrate that BranchSBM can accurately recover branching differentiation dynamics without intermediate supervision. + +![Veres Experiment](assets/veres.png) + +## Cell Perturbation Modelling Experiment 💉 +Predicting the effects of perturbation on cell state dynamics is a crucial problem for therapeutic design. In this experiment, we leverage BranchSBM to model the **trajectories of a single cell line from a single homogeneous state to multiple heterogeneous states after a drug-induced perturbation**. We demonstrate that BranchSBM is capable of modeling high-dimensional gene expression data and learning branched trajectories that accurately reconstruct diverging perturbed cell populations. + +We extract the data for a single cell line (A-549) under perturbation with Clonidine and Trametinib at 5 µL, selected based on cell abundance and response diversity from the Tahoe-100M dataset. + +For the Clonidine perturbation data, we show that **BranchSBM reconstructs the ground-truth distributions, capturing the location and spread of the dataset**, whereas single-branch SBM fails to differentiate cells in cluster 1 that differ from cluster 0 in higher-dimensional principal components. We also show that BranchSBM can simulate trajectories in high-dimensional state spaces by *scaling up to 150 PCs*. + +![Clonidine Experiment](assets/clonidine.png) + +We further show that BranchSBM can **scale beyond two branches by modeling the perturbed cell population of Trametinib-treated cells**, which diverge into *three distinct clusters*. We trained BranchSBM with three endpoints and single-branch SBM with one endpoint containing all three clusters on the top 50 PCs. + +![Trametinib Experiment](assets/trametinib.png) + + +## Citation +If you find this repository helpful for your publications, please consider citing our paper: +``` +@article{tang2026branchsbm, + title={Branched Schrödinger Bridge Matching}, + author={Tang, Sophia and Zhang, Yinuo and Tong, Alexander and Chatterjee, Pranam}, + journal={14th International Conference on Learning Representations (ICLR 2026)}, + year={2026} +} +``` +To use this repository, you agree to abide by the MIT License. \ No newline at end of file diff --git a/assets/branchsbm.png b/assets/branchsbm.png new file mode 100755 index 0000000000000000000000000000000000000000..e6214c738aebabd196480f96a390d0523236ab36 --- /dev/null +++ b/assets/branchsbm.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a761a2158f833de59d96c2319e33b1197795a0bb6589de549316b34edd30be6c +size 1707253 diff --git a/assets/branchsbm_anim.gif b/assets/branchsbm_anim.gif new file mode 100755 index 0000000000000000000000000000000000000000..320676c5ac92baa4e6e9888760d758bbc658f572 --- /dev/null +++ b/assets/branchsbm_anim.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8664b6455486dc347242cda8a84e951ef4f5e380bd2746f3a415de82cf919fc +size 1493953 diff --git a/assets/clonidine.png b/assets/clonidine.png new file mode 100755 index 0000000000000000000000000000000000000000..703ce0bcae451583d6f4ea54640096ffcd458090 --- /dev/null +++ b/assets/clonidine.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3b84a92120765db993d20bb39bbd56123a7d787aff95e6ee2e1799c44ebdc30 +size 12861177 diff --git a/assets/lidar.png b/assets/lidar.png new file mode 100755 index 0000000000000000000000000000000000000000..31241390c77ff6756b001703514f9e470dddc73a --- /dev/null +++ b/assets/lidar.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20160a1b105aea66d8305ceaead754e2e37c22595b711208652d7059fc76955a +size 2975359 diff --git a/assets/mouse.png b/assets/mouse.png new file mode 100755 index 0000000000000000000000000000000000000000..560e491d606afd32526d5e525ceeedeb936acc0f --- /dev/null +++ b/assets/mouse.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:725f5df7f9d9d2e50200030a6d927ca0e01247361875f60fd26ea01be33c08ef +size 8743832 diff --git a/assets/trametinib.png b/assets/trametinib.png new file mode 100755 index 0000000000000000000000000000000000000000..6a872d7a9792bc56135b046b4cc07abae283a912 --- /dev/null +++ b/assets/trametinib.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18e425c06dd2c2f3bc7fa12386d24f3b9e822b0bb76cc2c377baa13d3db5e82d +size 6049762 diff --git a/assets/veres.png b/assets/veres.png new file mode 100644 index 0000000000000000000000000000000000000000..8e68ddd25764f280337c01a6f91e6480bfcfd5c6 --- /dev/null +++ b/assets/veres.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:64dd688eff9c4c242bc267c617d327b801c2a03d7f91f9d30813b379da6e8fd1 +size 4642252 diff --git a/configs/.DS_Store b/configs/.DS_Store new file mode 100755 index 0000000000000000000000000000000000000000..be36cd203841a6b83c16bcfeb24c0e8d9a42bf5e Binary files /dev/null and b/configs/.DS_Store differ diff --git a/configs/clonidine_100D.yaml b/configs/clonidine_100D.yaml new file mode 100755 index 0000000000000000000000000000000000000000..8aaccb153d74b10e6ee9cc32e721d87b07529b63 --- /dev/null +++ b/configs/clonidine_100D.yaml @@ -0,0 +1,22 @@ +data_type: "tahoe" +data_name: "clonidine100D" +accelerator: "gpu" +hidden_dims_geopath: [1024, 1024, 1024] +hidden_dims_flow: [1024, 1024, 1024] +hidden_dims_growth: [1024, 1024, 1024] +dim: 100 +t_exclude: [] +time_geopath: true +whiten: false +velocity_metric: "rbf" +metric_patience: 25 +patience: 25 +n_centers: 300 +kappa: 2 +rho: -2.75 +alpha_metric: 1 +metric_epochs: 100 +branchsbm: true +seeds: [42] +branches: 2 +metric_clusters: 3 \ No newline at end of file diff --git a/configs/clonidine_150D.yaml b/configs/clonidine_150D.yaml new file mode 100755 index 0000000000000000000000000000000000000000..504e922fb4ec61764d309869711039aadc7ced83 --- /dev/null +++ b/configs/clonidine_150D.yaml @@ -0,0 +1,22 @@ +data_type: "tahoe" +data_name: "clonidine150D" +accelerator: "gpu" +hidden_dims_geopath: [1024, 1024, 1024] +hidden_dims_flow: [1024, 1024, 1024] +hidden_dims_growth: [1024, 1024, 1024] +dim: 150 +t_exclude: [] +time_geopath: true +whiten: false +velocity_metric: "rbf" +metric_patience: 25 +patience: 25 +n_centers: 300 +kappa: 3 +rho: -2.75 +alpha_metric: 1 +metric_epochs: 100 +branchsbm: true +seeds: [42] +branches: 2 +metric_clusters: 3 \ No newline at end of file diff --git a/configs/clonidine_50D.yaml b/configs/clonidine_50D.yaml new file mode 100755 index 0000000000000000000000000000000000000000..24e507d1f8928dba2ae54b99683b60a5bf01500e --- /dev/null +++ b/configs/clonidine_50D.yaml @@ -0,0 +1,22 @@ +data_type: "tahoe" +data_name: "clonidine50D" +accelerator: "gpu" +hidden_dims_geopath: [1024, 1024, 1024] +hidden_dims_flow: [1024, 1024, 1024] +hidden_dims_growth: [1024, 1024, 1024] +dim: 50 +t_exclude: [] +time_geopath: true +whiten: false +velocity_metric: "rbf" +metric_patience: 25 +patience: 25 +n_centers: 150 +kappa: 1.5 +rho: -2.75 +alpha_metric: 1 +metric_epochs: 100 +branchsbm: true +seeds: [42] +branches: 2 +metric_clusters: 3 \ No newline at end of file diff --git a/configs/clonidine_50Dsingle.yaml b/configs/clonidine_50Dsingle.yaml new file mode 100755 index 0000000000000000000000000000000000000000..d76089c02b2fe7983fd9667f442f2b4bb67d6f47 --- /dev/null +++ b/configs/clonidine_50Dsingle.yaml @@ -0,0 +1,22 @@ +data_type: "tahoe" +data_name: "clonidine50Dsingle" +accelerator: "gpu" +hidden_dims_geopath: [1024, 1024, 1024] +hidden_dims_flow: [1024, 1024, 1024] +hidden_dims_growth: [1024, 1024, 1024] +dim: 50 +t_exclude: [] +time_geopath: true +whiten: false +velocity_metric: "rbf" +metric_patience: 25 +patience: 25 +n_centers: 150 +kappa: 1.5 +rho: -2.75 +alpha_metric: 1 +metric_epochs: 100 +branchsbm: true +seeds: [42] +branches: 1 +metric_clusters: 2 \ No newline at end of file diff --git a/configs/lidar.yaml b/configs/lidar.yaml new file mode 100755 index 0000000000000000000000000000000000000000..2c0ed84ee97faa545d1914fa637a9818a6d5b641 --- /dev/null +++ b/configs/lidar.yaml @@ -0,0 +1,15 @@ +data_type: "lidar" +data_name: "lidar" +dim: 3 +whiten: true +t_exclude: [] +velocity_metric: "land" +gammas: [0.125] +rho: 0.001 +branchsbm: true +seeds: [42] +patience_geopath: 50 +metric_epochs: 100 +time_geopath: true +branches: 2 +metric_clusters: 3 \ No newline at end of file diff --git a/configs/lidar_single.yaml b/configs/lidar_single.yaml new file mode 100755 index 0000000000000000000000000000000000000000..4fb6f47671f4d9738fadd98048f0ae4d4f37ebb4 --- /dev/null +++ b/configs/lidar_single.yaml @@ -0,0 +1,15 @@ +data_type: "lidar" +data_name: "lidarsingle" +dim: 3 +whiten: true +t_exclude: [] +velocity_metric: "land" +gammas: [0.125] +rho: 0.001 +branchsbm: true +seeds: [42] +patience_geopath: 50 +metric_epochs: 100 +time_geopath: true +branches: 1 +metric_clusters: 2 \ No newline at end of file diff --git a/configs/mouse.yaml b/configs/mouse.yaml new file mode 100755 index 0000000000000000000000000000000000000000..b739a2faebfe3770e9d8ff817dc9819afb5a242f --- /dev/null +++ b/configs/mouse.yaml @@ -0,0 +1,18 @@ +data_type: "scrna" +data_name: "mouse" +hidden_dims_geopath: [64, 64, 64] +hidden_dims_flow: [64, 64, 64] +hidden_dims_growth: [64, 64, 64] +dim: 2 +whiten: false +t_exclude: [] +velocity_metric: "land" +gammas: [0.125] +rho: 0.001 +branchsbm: true +seeds: [42] +patience_geopath: 50 +metric_epochs: 100 +time_geopath: false +branches: 2 +metric_clusters: 2 \ No newline at end of file diff --git a/configs/mouse_single.yaml b/configs/mouse_single.yaml new file mode 100755 index 0000000000000000000000000000000000000000..9b0ea0fe83da25b759c9aeb9e69ea029a4e6c9a6 --- /dev/null +++ b/configs/mouse_single.yaml @@ -0,0 +1,18 @@ +data_type: "scrna" +data_name: "mousesingle" +hidden_dims_geopath: [64, 64, 64] +hidden_dims_flow: [64, 64, 64] +hidden_dims_growth: [64, 64, 64] +dim: 2 +whiten: false +t_exclude: [] +velocity_metric: "land" +gammas: [0.125] +rho: 0.001 +branchsbm: true +seeds: [42] +patience_geopath: 50 +metric_epochs: 100 +time_geopath: true +branches: 1 +metric_clusters: 2 \ No newline at end of file diff --git a/configs/trametinib.yaml b/configs/trametinib.yaml new file mode 100755 index 0000000000000000000000000000000000000000..1d6e7252a4eabd4bbd23f4b0d4611ca171ec552f --- /dev/null +++ b/configs/trametinib.yaml @@ -0,0 +1,22 @@ +data_type: "tahoe" +data_name: "trametinib" +accelerator: "gpu" +hidden_dims_geopath: [1024, 1024, 1024] +hidden_dims_flow: [1024, 1024, 1024] +hidden_dims_growth: [1024, 1024, 1024] +dim: 50 +t_exclude: [] +time_geopath: true +whiten: false +velocity_metric: "rbf" +metric_patience: 25 +patience: 25 +n_centers: 150 +kappa: 1.5 +rho: -2.75 +alpha_metric: 1 +metric_epochs: 100 +branchsbm: true +seeds: [42] +branches: 3 +metric_clusters: 4 \ No newline at end of file diff --git a/configs/trametinib_single.yaml b/configs/trametinib_single.yaml new file mode 100755 index 0000000000000000000000000000000000000000..de7ef4f8a4da8415bd6598b0501656e4d4c807c8 --- /dev/null +++ b/configs/trametinib_single.yaml @@ -0,0 +1,22 @@ +data_type: "tahoe" +data_name: "trametinibsingle" +accelerator: "gpu" +hidden_dims_geopath: [1024, 1024, 1024] +hidden_dims_flow: [1024, 1024, 1024] +hidden_dims_growth: [1024, 1024, 1024] +dim: 50 +t_exclude: [] +time_geopath: true +whiten: false +velocity_metric: "rbf" +metric_patience: 25 +patience: 25 +n_centers: 150 +kappa: 1.5 +rho: -2.75 +alpha_metric: 1 +metric_epochs: 100 +branchsbm: true +seeds: [42] +branches: 1 +metric_clusters: 2 \ No newline at end of file diff --git a/configs/veres.yaml b/configs/veres.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7f79d753e02fd8d9b581c2dfeac041c6b79f4712 --- /dev/null +++ b/configs/veres.yaml @@ -0,0 +1,25 @@ +data_type: "scrna" +data_name: "veres" +data_path: "data/Veres_alltime.csv" +accelerator: "gpu" +hidden_dims_geopath: [512, 512, 512] +hidden_dims_flow: [512, 512, 512] +hidden_dims_growth: [512, 512, 512] +dim: 30 +t_exclude: [] +time_geopath: true +whiten: false +velocity_metric: "rbf" +metric_patience: 25 +patience: 25 +patience_geopath: 50 +n_centers: 300 +kappa: 2 +rho: 0.001 +alpha_metric: 1.0 +metric_epochs: 100 +branchsbm: true +seeds: [42] +branches: 5 +metric_clusters: 2 +batch_size: 256 diff --git a/dataloaders/.DS_Store b/dataloaders/.DS_Store new file mode 100755 index 0000000000000000000000000000000000000000..29c2de01900f375b869b2bb69febaa2b30525a88 Binary files /dev/null and b/dataloaders/.DS_Store differ diff --git a/dataloaders/clonidine_single_branch.py b/dataloaders/clonidine_single_branch.py new file mode 100755 index 0000000000000000000000000000000000000000..fea21c088f95aada5d42c67a450ea8e53d322367 --- /dev/null +++ b/dataloaders/clonidine_single_branch.py @@ -0,0 +1,265 @@ +import torch +import sys +from sklearn.preprocessing import StandardScaler +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from lightning.pytorch.utilities.combined_loader import CombinedLoader +import pandas as pd +import numpy as np +from functools import partial +from scipy.spatial import cKDTree +from sklearn.cluster import KMeans +from torch.utils.data import TensorDataset + + +class ClonidineSingleBranchDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.split_ratios = args.split_ratios + + self.dim = args.dim + print("dimension") + print(self.dim) + # Path to your combined data + self.data_path = "./data/pca_and_leiden_labels.csv" + self.num_timesteps = 2 + self.args = args + self._prepare_data() + + def _prepare_data(self): + df = pd.read_csv(self.data_path, comment='#') + df = df.iloc[:, 1:] + df = df.replace('', np.nan) + pc_cols = df.columns[:self.dim] + for col in pc_cols: + df[col] = pd.to_numeric(df[col], errors='coerce') + leiden_dmso_col = 'leiden_DMSO_TF_0.0uM' + leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM' + + dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column + clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column + + dmso_data = df[dmso_mask].copy() + clonidine_data = df[clonidine_mask].copy() + + top_clonidine_clusters = ['0.0', '4.0'] + + x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]] + x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]] + + x1_1_coords = x1_1_data[pc_cols].values + x1_2_coords = x1_2_data[pc_cols].values + + x1_1_coords = x1_1_coords.astype(float) + x1_2_coords = x1_2_coords.astype(float) + + # Target size is now the minimum across all three endpoint clusters + target_size = min(len(x1_1_coords), len(x1_2_coords),) + + # Helper function to select points closest to centroid + def select_closest_to_centroid(coords, target_size): + if len(coords) <= target_size: + return coords + + # Calculate centroid + centroid = np.mean(coords, axis=0) + + # Calculate distances to centroid + distances = np.linalg.norm(coords - centroid, axis=1) + + # Get indices of closest points + closest_indices = np.argsort(distances)[:target_size] + + return coords[closest_indices] + + # Sample all endpoint clusters to target size using centroid-based selection + x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size) + x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size) + + dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts() + + # DMSO (unchanged) + largest_dmso_cluster = dmso_cluster_counts.index[0] + dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster] + + dmso_coords = dmso_cluster_data[pc_cols].values + + # Random sampling from largest DMSO cluster to match target size + # For DMSO, we'll also use centroid-based selection for consistency + if len(dmso_coords) >= target_size: + x0_coords = select_closest_to_centroid(dmso_coords, target_size) + else: + # If largest cluster is smaller than target, use all of it and pad with other DMSO cells + remaining_needed = target_size - len(dmso_coords) + other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster] + other_dmso_coords = other_dmso_data[pc_cols].values + + if len(other_dmso_coords) >= remaining_needed: + # Select closest to centroid from other DMSO cells + other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed) + x0_coords = np.vstack([dmso_coords, other_selected]) + else: + # Use all available DMSO cells and reduce target size + all_dmso_coords = dmso_data[pc_cols].values + target_size = min(target_size, len(all_dmso_coords)) + x0_coords = select_closest_to_centroid(all_dmso_coords, target_size) + + # Re-select endpoint clusters with updated target size + x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size) + x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size) + + # No need to resample since we already selected the right number + # The endpoint clusters are already at target_size from centroid-based selection + + self.n_samples = target_size + + x0 = torch.tensor(x0_coords, dtype=torch.float32) + x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32) + x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32) + x1 = torch.cat([x1_1, x1_2], dim=0) + + self.coords_t0 = x0 + self.coords_t1 = x1 + self.time_labels = np.concatenate([ + np.zeros(len(self.coords_t0)), # t=0 + np.ones(len(self.coords_t1)), # t=1 + ]) + + split_index = int(target_size * self.split_ratios[0]) + + if target_size - split_index < self.batch_size: + split_index = target_size - self.batch_size + print('total count is:', target_size) + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + train_x1 = x1[:split_index] + val_x1 = x1[split_index:] + + + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0) + + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0) + + # Updated train dataloaders to include x1_3 + self.train_dataloaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + self.val_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), + "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + all_coords = df[pc_cols].dropna().values.astype(float) + self.dataset = torch.tensor(all_coords, dtype=torch.float32) + self.tree = cKDTree(all_coords) + + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), + } + + # Updated metric samples - now using 4 clusters instead of 3 + #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy()) + km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset.numpy()) + + cluster_labels = km_all.labels_ + + cluster_0_mask = cluster_labels == 0 + cluster_1_mask = cluster_labels == 1 + + samples = self.dataset.cpu().numpy() + + cluster_0_data = samples[cluster_0_mask] + cluster_1_data = samples[cluster_1_mask] + + self.metric_samples_dataloaders = [ + DataLoader( + torch.tensor(cluster_1_data, dtype=torch.float32), + batch_size=cluster_1_data.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + torch.tensor(cluster_0_data, dtype=torch.float32), + batch_size=cluster_0_data.shape[0], + shuffle=False, + drop_last=False, + ), + ] + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_manifold_proj(self, points): + """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting""" + return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset) + + @staticmethod + def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3): + """ + Apply local smoothing based on k-nearest neighbors in the full dataset + This replaces the plane projection for 2D manifold regularization + """ + points_np = x.detach().cpu().numpy() + _, idx = tree.query(points_np, k=k) + nearest_pts = dataset[idx] # Shape: (batch_size, k, 2) + + # Compute weighted average of neighbors + dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + weights = weights / weights.sum(dim=1, keepdim=True) + + # Weighted average of neighbors + smoothed = (weights * nearest_pts).sum(dim=1) + + # Blend original point with smoothed version + alpha = 0.3 # How much smoothing to apply + return (1 - alpha) * x + alpha * smoothed + + def get_timepoint_data(self): + """Return data organized by timepoints for visualization""" + return { + 't0': self.coords_t0, + 't1': self.coords_t1, + 'time_labels': self.time_labels + } diff --git a/dataloaders/clonidine_v2_data.py b/dataloaders/clonidine_v2_data.py new file mode 100755 index 0000000000000000000000000000000000000000..d74fce022796da0b29b76ee1221607ac5b88ff2e --- /dev/null +++ b/dataloaders/clonidine_v2_data.py @@ -0,0 +1,280 @@ +import torch +import sys +from sklearn.preprocessing import StandardScaler +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from lightning.pytorch.utilities.combined_loader import CombinedLoader +import pandas as pd +import numpy as np +from functools import partial +from scipy.spatial import cKDTree +from sklearn.cluster import KMeans +from torch.utils.data import TensorDataset + + +class ClonidineV2DataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.split_ratios = args.split_ratios + + self.dim = args.dim + print("dimension") + print(self.dim) + # Path to your combined data + self.data_path = "./data/pca_and_leiden_labels.csv" + self.num_timesteps = 2 + self.args = args + self._prepare_data() + + def _prepare_data(self): + df = pd.read_csv(self.data_path, comment='#') + df = df.iloc[:, 1:] + df = df.replace('', np.nan) + pc_cols = df.columns[:self.dim] + for col in pc_cols: + df[col] = pd.to_numeric(df[col], errors='coerce') + leiden_dmso_col = 'leiden_DMSO_TF_0.0uM' + leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM' + + dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column + clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column + + dmso_data = df[dmso_mask].copy() + clonidine_data = df[clonidine_mask].copy() + + top_clonidine_clusters = ['0.0', '4.0'] + + x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]] + x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]] + + x1_1_coords = x1_1_data[pc_cols].values + x1_2_coords = x1_2_data[pc_cols].values + + x1_1_coords = x1_1_coords.astype(float) + x1_2_coords = x1_2_coords.astype(float) + + # Target size is now the minimum across all three endpoint clusters + target_size = min(len(x1_1_coords), len(x1_2_coords),) + + # Helper function to select points closest to centroid + def select_closest_to_centroid(coords, target_size): + if len(coords) <= target_size: + return coords + + # Calculate centroid + centroid = np.mean(coords, axis=0) + + # Calculate distances to centroid + distances = np.linalg.norm(coords - centroid, axis=1) + + # Get indices of closest points + closest_indices = np.argsort(distances)[:target_size] + + return coords[closest_indices] + + # Sample all endpoint clusters to target size using centroid-based selection + x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size) + x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size) + + dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts() + + # DMSO (unchanged) + largest_dmso_cluster = dmso_cluster_counts.index[0] + dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster] + + dmso_coords = dmso_cluster_data[pc_cols].values + + # Random sampling from largest DMSO cluster to match target size + # For DMSO, we'll also use centroid-based selection for consistency + if len(dmso_coords) >= target_size: + x0_coords = select_closest_to_centroid(dmso_coords, target_size) + else: + # If largest cluster is smaller than target, use all of it and pad with other DMSO cells + remaining_needed = target_size - len(dmso_coords) + other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster] + other_dmso_coords = other_dmso_data[pc_cols].values + + if len(other_dmso_coords) >= remaining_needed: + # Select closest to centroid from other DMSO cells + other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed) + x0_coords = np.vstack([dmso_coords, other_selected]) + else: + # Use all available DMSO cells and reduce target size + all_dmso_coords = dmso_data[pc_cols].values + target_size = min(target_size, len(all_dmso_coords)) + x0_coords = select_closest_to_centroid(all_dmso_coords, target_size) + + # Re-select endpoint clusters with updated target size + x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size) + x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size) + + # No need to resample since we already selected the right number + # The endpoint clusters are already at target_size from centroid-based selection + + self.n_samples = target_size + + x0 = torch.tensor(x0_coords, dtype=torch.float32) + x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32) + x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32) + + self.coords_t0 = x0 + self.coords_t1_1 = x1_1 + self.coords_t1_2 = x1_2 + self.time_labels = np.concatenate([ + np.zeros(len(self.coords_t0)), # t=0 + np.ones(len(self.coords_t1_1)), # t=1 + np.ones(len(self.coords_t1_2)), + ]) + + split_index = int(target_size * self.split_ratios[0]) + + if target_size - split_index < self.batch_size: + split_index = target_size - self.batch_size + print('total count is:', target_size) + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + train_x1_1 = x1_1[:split_index] + val_x1_1 = x1_1[split_index:] + train_x1_2 = x1_2[:split_index] + val_x1_2 = x1_2[split_index:] + + + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5) + train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5) + + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5) + val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5) + + # Updated train dataloaders to include x1_3 + self.train_dataloaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + self.val_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), + "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + all_coords = df[pc_cols].dropna().values.astype(float) + self.dataset = torch.tensor(all_coords, dtype=torch.float32) + self.tree = cKDTree(all_coords) + + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), + } + + km_all = KMeans(n_clusters=3, random_state=0).fit(self.dataset.numpy()) + + cluster_labels = km_all.labels_ + + cluster_0_mask = cluster_labels == 0 + cluster_1_mask = cluster_labels == 1 + cluster_2_mask = cluster_labels == 2 + + samples = self.dataset.cpu().numpy() + + cluster_0_data = samples[cluster_0_mask] + cluster_1_data = samples[cluster_1_mask] + cluster_2_data = samples[cluster_2_mask] + + self.metric_samples_dataloaders = [ + DataLoader( + torch.tensor(cluster_2_data, dtype=torch.float32), + batch_size=cluster_2_data.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + torch.tensor(cluster_0_data, dtype=torch.float32), + batch_size=cluster_0_data.shape[0], + shuffle=False, + drop_last=False, + ), + + DataLoader( + torch.tensor(cluster_1_data, dtype=torch.float32), + batch_size=cluster_1_data.shape[0], + shuffle=False, + drop_last=False, + ), + ] + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_manifold_proj(self, points): + """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting""" + return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset) + + @staticmethod + def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3): + """ + Apply local smoothing based on k-nearest neighbors in the full dataset + This replaces the plane projection for 2D manifold regularization + """ + points_np = x.detach().cpu().numpy() + _, idx = tree.query(points_np, k=k) + nearest_pts = dataset[idx] # Shape: (batch_size, k, 2) + + # Compute weighted average of neighbors + dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + weights = weights / weights.sum(dim=1, keepdim=True) + + # Weighted average of neighbors + smoothed = (weights * nearest_pts).sum(dim=1) + + # Blend original point with smoothed version + alpha = 0.3 # How much smoothing to apply + return (1 - alpha) * x + alpha * smoothed + + def get_timepoint_data(self): + """Return data organized by timepoints for visualization""" + return { + 't0': self.coords_t0, + 't1_1': self.coords_t1_1, + 't1_2': self.coords_t1_2, + 'time_labels': self.time_labels + } + diff --git a/dataloaders/lidar_data.py b/dataloaders/lidar_data.py new file mode 100755 index 0000000000000000000000000000000000000000..4468db3f2135355ab1fa7e314c2cd47a70e99afd --- /dev/null +++ b/dataloaders/lidar_data.py @@ -0,0 +1,529 @@ +import torch +import sys +from sklearn.preprocessing import StandardScaler +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from pytorch_lightning.utilities.combined_loader import CombinedLoader +import laspy +import numpy as np +from scipy.spatial import cKDTree +import math +from functools import partial +from torch.utils.data import TensorDataset + + +class GaussianMM: + def __init__(self, mu, var): + super().__init__() + self.centers = torch.tensor(mu) + self.logstd = torch.tensor(var).log() / 2.0 + self.K = self.centers.shape[0] + + def logprob(self, x): + logprobs = self.normal_logprob( + x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd + ) + logprobs = torch.sum(logprobs, dim=2) + return torch.logsumexp(logprobs, dim=1) - math.log(self.K) + + def normal_logprob(self, z, mean, log_std): + mean = mean + torch.tensor(0.0) + log_std = log_std + torch.tensor(0.0) + c = torch.tensor([math.log(2 * math.pi)]).to(z) + inv_sigma = torch.exp(-log_std) + tmp = (z - mean) * inv_sigma + return -0.5 * (tmp * tmp + 2 * log_std + c) + + def __call__(self, n_samples): + idx = torch.randint(self.K, (n_samples,)).to(self.centers.device) + mean = self.centers[idx] + return torch.randn(*mean.shape).to(mean) * torch.exp(self.logstd) + mean + +class BranchedLidarDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.data_path = args.data_path + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.p0_mu = [ + [-4.5, -4.0, 0.5], + [-4.2, -3.5, 0.5], + [-4.0, -3.0, 0.5], + [-3.75, -2.5, 0.5], + ] + self.p0_var = 0.02 + + self.p1_1_mu = [ + [-2.5, -0.25, 0.5], + [-2.25, 0.675, 0.5], + [-2, 1.5, 0.5], + ] + self.p1_2_mu = [ + [2, -2, 0.5], + [2.6, -1.25, 0.5], + [3.2, -0.5, 0.5] + ] + + self.p1_var = 0.03 + self.k = 20 + self.n_samples = 5000 + self.num_timesteps = 2 + self.split_ratios = args.split_ratios + self._prepare_data() + + def assign_region(self): + all_centers = { + 0: torch.tensor(self.p0_mu), # Region 0: p0 + 1: torch.tensor(self.p1_1_mu), # Region 1: p1_1 + 2: torch.tensor(self.p1_2_mu), # Region 2: p1_2 + } + + dataset = self.dataset.to(torch.float32) + N = dataset.shape[0] + assignments = torch.zeros(N, dtype=torch.long) + + # For each point, compute min distance to each region's centers + for i in range(N): + point = dataset[i] + min_dist = float("inf") + best_region = 0 + for region, centers in all_centers.items(): + dists = ((centers - point)**2).sum(dim=1) + region_min = dists.min() + if region_min < min_dist: + min_dist = region_min + best_region = region + assignments[i] = best_region + return assignments + + def _prepare_data(self): + las = laspy.read(self.data_path) + # Extract only "ground" points. + self.mask = las.classification == 2 + # Original Preprocessing + x_offset, x_scale = las.header.offsets[0], las.header.scales[0] + y_offset, y_scale = las.header.offsets[1], las.header.scales[1] + z_offset, z_scale = las.header.offsets[2], las.header.scales[2] + dataset = np.vstack( + ( + las.X[self.mask] * x_scale + x_offset, + las.Y[self.mask] * y_scale + y_offset, + las.Z[self.mask] * z_scale + z_offset, + ) + ).transpose() + mi = dataset.min(axis=0, keepdims=True) + ma = dataset.max(axis=0, keepdims=True) + dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0] + + self.dataset = torch.tensor(dataset, dtype=torch.float32) + self.tree = cKDTree(dataset) + + x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples) + x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples) + x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples) + + x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian) + x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian) + x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian) + + split_index = int(self.n_samples * self.split_ratios[0]) + + self.scaler = StandardScaler() + if self.whiten: + self.dataset = torch.tensor( + self.scaler.fit_transform(dataset), dtype=torch.float32 + ) + x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32) + x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32) + x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32) + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + + # branches + train_x1_1 = x1_1[:split_index] + print("train_x1_1") + print(train_x1_1.shape) + val_x1_1 = x1_1[split_index:] + train_x1_2 = x1_2[:split_index] + val_x1_2 = x1_2[split_index:] + + self.val_x0 = val_x0 + + # Adjust split_index to ensure minimum validation samples + if self.n_samples - split_index < self.batch_size: + split_index = self.n_samples - self.batch_size + + self.train_dataloaders = { + "x0": DataLoader(train_x0, batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_1": DataLoader(train_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(train_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True), + } + self.val_dataloaders = { + "x0": DataLoader(val_x0, batch_size=self.batch_size, shuffle=False, drop_last=True), + "x1_1": DataLoader(val_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(val_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True), + } + # to edit? + self.test_dataloaders = [ + DataLoader( + self.val_x0, + batch_size=self.val_x0.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + self.dataset, + batch_size=self.dataset.shape[0], + shuffle=False, + drop_last=False, + ), + ] + + points = self.dataset.cpu().numpy() + x, y = points[:, 0], points[:, 1] + # Diagonal-based coordinates (rotated 45°) + u = (x + y) / np.sqrt(2) # along x=y + # start region (A) using u + u_thresh = np.percentile(u, 30) # tweak this threshold to control size + mask_A = u <= u_thresh + + # among the rest, split by x=y diagonal + remaining = ~mask_A + mask_B = remaining & (x < y) # left of diagonal + mask_C = remaining & (x >= y) # right of diagonal + + # Assign dataloaders + self.metric_samples_dataloaders = [ + DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False), + DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False), + DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False), + ] + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def test_dataloader(self): + return CombinedLoader(self.test_dataloaders) + + def get_tangent_proj(self, points): + w = self.get_tangent_plane(points) + return partial(BranchedLidarDataModule.projection_op, w=w) + + def get_tangent_plane(self, points, temp=1e-3): + points_np = points.detach().cpu().numpy() + _, idx = self.tree.query(points_np, k=self.k) + nearest_pts = self.dataset[idx] + nearest_pts = torch.tensor(nearest_pts).to(points) + + dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + + # Fits plane with least vertical distance. + w = BranchedLidarDataModule.fit_plane(nearest_pts, weights) + return w + + @staticmethod + def fit_plane(points, weights=None): + """Expects points to be of shape (..., 3). + Returns [a, b, c] such that the plane is defined as + ax + by + c = z + """ + D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1) + z = points[..., 2] + if weights is not None: + Dtrans = D.transpose(-1, -2) + else: + DW = D * weights + Dtrans = DW.transpose(-1, -2) + w = torch.linalg.solve( + torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1)) + ).squeeze(-1) + return w + + @staticmethod + def projection_op(x, w): + """Projects points to a plane defined by w.""" + # Normal vector to the tangent plane. + n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1) + + pn = torch.sum(x * n, dim=-1, keepdim=True) + nn = torch.sum(n * n, dim=-1, keepdim=True) + + # Offset. + d = w[..., 2:3] + + # Projection of x onto n. + projn_x = ((pn + d) / nn) * n + + # Remove component in the normal direction. + return x - projn_x + +class WeightedBranchedLidarDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.data_path = args.data_path + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.p0_mu = [ + [-4.5, -4.0, 0.5], + [-4.2, -3.5, 0.5], + [-4.0, -3.0, 0.5], + [-3.75, -2.5, 0.5], + ] + self.p0_var = 0.02 + # multiple p1 for each branch + #changed + self.p1_1_mu = [ + [-2.5, -0.25, 0.5], + [-2.25, 0.675, 0.5], + [-2, 1.5, 0.5], + ] + self.p1_2_mu = [ + [2, -2, 0.5], + [2.6, -1.25, 0.5], + [3.2, -0.5, 0.5] + ] + + self.p1_var = 0.03 + self.k = 20 + self.n_samples = 5000 + self.num_timesteps = 2 + self.split_ratios = args.split_ratios + + self.num_timesteps = 2 + self.metric_clusters = 3 + self.args = args + self._prepare_data() + + def _prepare_data(self): + las = laspy.read(self.data_path) + # Extract only "ground" points. + self.mask = las.classification == 2 + # Original Preprocessing + x_offset, x_scale = las.header.offsets[0], las.header.scales[0] + y_offset, y_scale = las.header.offsets[1], las.header.scales[1] + z_offset, z_scale = las.header.offsets[2], las.header.scales[2] + dataset = np.vstack( + ( + las.X[self.mask] * x_scale + x_offset, + las.Y[self.mask] * y_scale + y_offset, + las.Z[self.mask] * z_scale + z_offset, + ) + ).transpose() + mi = dataset.min(axis=0, keepdims=True) + ma = dataset.max(axis=0, keepdims=True) + dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0] + + self.dataset = torch.tensor(dataset, dtype=torch.float32) + self.tree = cKDTree(dataset) + + x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples) + x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples) + x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples) + + x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian) + x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian) + x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian) + + split_index = int(self.n_samples * self.split_ratios[0]) + + self.scaler = StandardScaler() + if self.whiten: + self.dataset = torch.tensor( + self.scaler.fit_transform(dataset), dtype=torch.float32 + ) + x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32) + x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32) + x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32) + + self.coords_t0 = x0 + self.coords_t1_1 = x1_1 + self.coords_t1_2 = x1_2 + self.time_labels = np.concatenate([ + np.zeros(len(self.coords_t0)), # t=0 + np.ones(len(self.coords_t1_1)), # t=1 + np.ones(len(self.coords_t1_2)), # t=1 + ]) + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + + # branches + train_x1_1 = x1_1[:split_index] + + val_x1_1 = x1_1[split_index:] + train_x1_2 = x1_2[:split_index] + val_x1_2 = x1_2[split_index:] + + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5) + train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5) + + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5) + val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5) + + # Adjust split_index to ensure minimum validation samples + if self.n_samples - split_index < self.batch_size: + split_index = self.n_samples - self.batch_size + + self.train_dataloaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + self.val_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), + "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + # to edit? + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), + "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), + } + + points = self.dataset.cpu().numpy() + x, y = points[:, 0], points[:, 1] + # Diagonal-based coordinates (rotated 45°) + u = (x + y) / np.sqrt(2) # along x=y + # start region (A) using u + u_thresh = np.percentile(u, 30) # tweak this threshold to control size + mask_A = u <= u_thresh + + # among the rest, split by x=y diagonal + remaining = ~mask_A + mask_B = remaining & (x < y) # left of diagonal + mask_C = remaining & (x >= y) # right of diagonal + + # Assign dataloaders + self.metric_samples_dataloaders = [ + DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False), + DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False), + DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False), + ] + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_tangent_proj(self, points): + w = self.get_tangent_plane(points) + return partial(BranchedLidarDataModule.projection_op, w=w) + + def get_tangent_plane(self, points, temp=1e-3): + points_np = points.detach().cpu().numpy() + _, idx = self.tree.query(points_np, k=self.k) + nearest_pts = self.dataset[idx] + nearest_pts = torch.tensor(nearest_pts).to(points) + + dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + + # Fits plane with least vertical distance. + w = BranchedLidarDataModule.fit_plane(nearest_pts, weights) + return w + + @staticmethod + def fit_plane(points, weights=None): + """Expects points to be of shape (..., 3). + Returns [a, b, c] such that the plane is defined as + ax + by + c = z + """ + D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1) + z = points[..., 2] + if weights is not None: + Dtrans = D.transpose(-1, -2) + else: + DW = D * weights + Dtrans = DW.transpose(-1, -2) + w = torch.linalg.solve( + torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1)) + ).squeeze(-1) + return w + + @staticmethod + def projection_op(x, w): + """Projects points to a plane defined by w.""" + # Normal vector to the tangent plane. + n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1) + + pn = torch.sum(x * n, dim=-1, keepdim=True) + nn = torch.sum(n * n, dim=-1, keepdim=True) + + # Offset. + d = w[..., 2:3] + + # Projection of x onto n. + projn_x = ((pn + d) / nn) * n + + # Remove component in the normal direction. + return x - projn_x + + def get_timepoint_data(self): + """Return data organized by timepoints for visualization""" + return { + 't0': self.coords_t0, + 't1_1': self.coords_t1_1, + 't1_2': self.coords_t1_2, + 'time_labels': self.time_labels + } + +def get_datamodule(): + datamodule = WeightedBranchedLidarDataModule(args) + datamodule.setup(stage="fit") + return datamodule \ No newline at end of file diff --git a/dataloaders/lidar_data_single.py b/dataloaders/lidar_data_single.py new file mode 100755 index 0000000000000000000000000000000000000000..ffdcf27d46f5a6f4f2357ca1071b763422558869 --- /dev/null +++ b/dataloaders/lidar_data_single.py @@ -0,0 +1,274 @@ +import torch +import sys +from sklearn.preprocessing import StandardScaler +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from pytorch_lightning.utilities.combined_loader import CombinedLoader +import laspy +import numpy as np +from scipy.spatial import cKDTree +import math +from functools import partial +from torch.utils.data import TensorDataset + + +class GaussianMM: + def __init__(self, mu, var): + super().__init__() + self.centers = torch.tensor(mu) + self.logstd = torch.tensor(var).log() / 2.0 + self.K = self.centers.shape[0] + + def logprob(self, x): + logprobs = self.normal_logprob( + x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd + ) + logprobs = torch.sum(logprobs, dim=2) + return torch.logsumexp(logprobs, dim=1) - math.log(self.K) + + def normal_logprob(self, z, mean, log_std): + mean = mean + torch.tensor(0.0) + log_std = log_std + torch.tensor(0.0) + c = torch.tensor([math.log(2 * math.pi)]).to(z) + inv_sigma = torch.exp(-log_std) + tmp = (z - mean) * inv_sigma + return -0.5 * (tmp * tmp + 2 * log_std + c) + + def __call__(self, n_samples): + idx = torch.randint(self.K, (n_samples,)).to(self.centers.device) + mean = self.centers[idx] + return torch.randn(*mean.shape).to(mean) * torch.exp(self.logstd) + mean + +class LidarSingleDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.data_path = args.data_path + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.p0_mu = [ + [-4.5, -4.0, 0.5], + [-4.2, -3.5, 0.5], + [-4.0, -3.0, 0.5], + [-3.75, -2.5, 0.5], + ] + self.p0_var = 0.02 + # multiple p1 for each branch + #changed + self.p1_1_mu = [ + [-2.5, -0.25, 0.5], + [-2.25, 0.675, 0.5], + [-2, 1.5, 0.5], + ] + self.p1_2_mu = [ + [2, -2, 0.5], + [2.6, -1.25, 0.5], + [3.2, -0.5, 0.5] + ] + + self.p1_var = 0.03 + self.k = 20 + self.n_samples = 5000 + self.num_timesteps = 2 + self.split_ratios = args.split_ratios + + self.num_timesteps = 2 + self.metric_clusters = 3 + self.args = args + self._prepare_data() + + def _prepare_data(self): + las = laspy.read(self.data_path) + # Extract only "ground" points. + self.mask = las.classification == 2 + # Original Preprocessing + x_offset, x_scale = las.header.offsets[0], las.header.scales[0] + y_offset, y_scale = las.header.offsets[1], las.header.scales[1] + z_offset, z_scale = las.header.offsets[2], las.header.scales[2] + dataset = np.vstack( + ( + las.X[self.mask] * x_scale + x_offset, + las.Y[self.mask] * y_scale + y_offset, + las.Z[self.mask] * z_scale + z_offset, + ) + ).transpose() + mi = dataset.min(axis=0, keepdims=True) + ma = dataset.max(axis=0, keepdims=True) + dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0] + + self.dataset = torch.tensor(dataset, dtype=torch.float32) + self.tree = cKDTree(dataset) + + x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples) + x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples) + x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples) + + x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian) + x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian) + x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian) + + split_index = int(self.n_samples * self.split_ratios[0]) + + self.scaler = StandardScaler() + if self.whiten: + self.dataset = torch.tensor( + self.scaler.fit_transform(dataset), dtype=torch.float32 + ) + x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32) + x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32) + x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32) + x1 = torch.cat([x1_1, x1_2], dim=0) + + self.coords_t0 = x0 + self.coords_t1 = x1 + self.time_labels = np.concatenate([ + np.zeros(len(self.coords_t0)), # t=0 + np.ones(len(self.coords_t1)), # t=1 + ]) + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + + # branches + train_x1 = x1[:split_index] + val_x1 = x1[split_index:] + + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0) + + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0) + + # Adjust split_index to ensure minimum validation samples + if self.n_samples - split_index < self.batch_size: + split_index = self.n_samples - self.batch_size + + self.train_dataloaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + self.val_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + # to edit? + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=False), + "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=True, drop_last=False), + } + + points = self.dataset.cpu().numpy() + x, y = points[:, 0], points[:, 1] + # Diagonal-based coordinates (rotated 45°) + u = (x + y) / np.sqrt(2) # along x=y + # start region (A) using u + u_thresh = np.percentile(u, 30) # tweak this threshold to control size + mask_A = u <= u_thresh + + # among the rest, split by x=y diagonal + remaining = ~mask_A + mask_B = remaining & (x < y) # left of diagonal + mask_C = remaining & (x >= y) # right of diagonal + + # Assign dataloaders + self.metric_samples_dataloaders = [ + DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False), + DataLoader(torch.tensor(points[remaining], dtype=torch.float32), batch_size=points[remaining].shape[0], shuffle=False), + ] + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_tangent_proj(self, points): + w = self.get_tangent_plane(points) + return partial(LidarSingleDataModule.projection_op, w=w) + + def get_tangent_plane(self, points, temp=1e-3): + points_np = points.detach().cpu().numpy() + _, idx = self.tree.query(points_np, k=self.k) + nearest_pts = self.dataset[idx] + nearest_pts = torch.tensor(nearest_pts).to(points) + + dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + + # Fits plane with least vertical distance. + w = LidarSingleDataModule.fit_plane(nearest_pts, weights) + return w + + @staticmethod + def fit_plane(points, weights=None): + """Expects points to be of shape (..., 3). + Returns [a, b, c] such that the plane is defined as + ax + by + c = z + """ + D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1) + z = points[..., 2] + if weights is not None: + Dtrans = D.transpose(-1, -2) + else: + DW = D * weights + Dtrans = DW.transpose(-1, -2) + w = torch.linalg.solve( + torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1)) + ).squeeze(-1) + return w + + @staticmethod + def projection_op(x, w): + """Projects points to a plane defined by w.""" + # Normal vector to the tangent plane. + n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1) + + pn = torch.sum(x * n, dim=-1, keepdim=True) + nn = torch.sum(n * n, dim=-1, keepdim=True) + + # Offset. + d = w[..., 2:3] + + # Projection of x onto n. + projn_x = ((pn + d) / nn) * n + + # Remove component in the normal direction. + return x - projn_x + + def get_timepoint_data(self): + """Return data organized by timepoints for visualization""" + return { + 't0': self.coords_t0, + 't1': self.coords_t1, + 'time_labels': self.time_labels + } diff --git a/dataloaders/mouse_data.py b/dataloaders/mouse_data.py new file mode 100755 index 0000000000000000000000000000000000000000..911d54dc4478c387b1e520f08595c68b1ba05aac --- /dev/null +++ b/dataloaders/mouse_data.py @@ -0,0 +1,453 @@ +import torch +import sys +from sklearn.preprocessing import StandardScaler +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from lightning.pytorch.utilities.combined_loader import CombinedLoader +import numpy as np +from scipy.spatial import cKDTree +import math +from functools import partial +from sklearn.cluster import KMeans, DBSCAN +import matplotlib.pyplot as plt +import pandas as pd +from torch.utils.data import TensorDataset + +class WeightedBranchedCellDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.data_path = args.data_path + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.k = 20 + self.n_samples = 1429 + self.num_timesteps = 2 # t=0, t=1, t=2 + self.split_ratios = args.split_ratios + self.metric_clusters = args.metric_clusters + self.args = args + self._prepare_data() + + + def _prepare_data(self): + print("Preparing cell data in BranchedCellDataModule") + + df = pd.read_csv(self.data_path) + + # Build dictionary of coordinates by time + coords_by_t = { + t: df[df["samples"] == t][["x1","x2"]].values + for t in sorted(df["samples"].unique()) + } + n0 = coords_by_t[0].shape[0] # Number of T=0 points + self.n_samples = n0 # Update n_samples to match actual data if changes + + # Cluster the t=2 cells into two branches + km = KMeans(n_clusters=2, random_state=42).fit(coords_by_t[2]) + df2 = df[df["samples"] == 2].copy() + df2["branch"] = km.labels_ + + cluster_counts = df2["branch"].value_counts().sort_index() + print(cluster_counts) + + # Sample n0 points from each branch + endpoints = {} + for b in (0, 1): + endpoints[b] = ( + df2[df2["branch"] == b] + .sample(n=n0, random_state=42)[["x1","x2"]] + .values + ) + + x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) # T=0 coordinates index + x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32) + x1_1 = torch.tensor(endpoints[0], dtype=torch.float32) # Branch index + x1_2 = torch.tensor(endpoints[1], dtype=torch.float32) # Branch index + + self.coords_t0 = x0 + self.coords_t1 = x_inter + self.coords_t2_1 = x1_1 + self.coords_t2_2 = x1_2 + self.time_labels = np.concatenate([ + np.zeros(len(self.coords_t0)), # t=0 + np.ones(len(self.coords_t1)), # t=1 + np.ones(len(self.coords_t2_1)) * 2, # t=1 + np.ones(len(self.coords_t2_2)) * 2, + ]) + + split_index = int(n0 * self.split_ratios[0]) + + if n0 - split_index < self.batch_size: + split_index = n0 - self.batch_size + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + train_x1_1 = x1_1[:split_index] + val_x1_1 = x1_1[split_index:] + train_x1_2 = x1_2[:split_index] + val_x1_2 = x1_2[split_index:] + + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5) + train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5) + + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5) + val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5) + + if self.n_samples - split_index < self.batch_size: + split_index = self.n_samples - self.batch_size + + self.train_dataloaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + self.val_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), + "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())]) + self.dataset = torch.tensor(all_data, dtype=torch.float32) + self.tree = cKDTree(all_data) + + # if whitening is enabled, need to apply this to the full dataset + #if self.whiten: + #self.scaler = StandardScaler() + #self.dataset = torch.tensor( + #self.scaler.fit_transform(all_data), dtype=torch.float32 + #) + + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), + } + + # Metric Dataloader + # K-means clustering of ALL points into 2 groups + if self.metric_clusters == 3: + km_all = KMeans(n_clusters=3, random_state=45).fit(self.dataset.numpy()) + cluster_labels = km_all.labels_ + + cluster_0_mask = cluster_labels == 0 + cluster_1_mask = cluster_labels == 1 + cluster_2_mask = cluster_labels == 2 + + samples = self.dataset.cpu().numpy() + + cluster_0_data = samples[cluster_0_mask] + cluster_1_data = samples[cluster_1_mask] + cluster_2_data = samples[cluster_2_mask] + + self.metric_samples_dataloaders = [ + DataLoader( + torch.tensor(cluster_1_data, dtype=torch.float32), + batch_size=cluster_1_data.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + torch.tensor(cluster_2_data, dtype=torch.float32), + batch_size=cluster_2_data.shape[0], + shuffle=False, + drop_last=False, + ), + + DataLoader( + torch.tensor(cluster_0_data, dtype=torch.float32), + batch_size=cluster_0_data.shape[0], + shuffle=False, + drop_last=False, + ), + ] + else: + km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy()) + cluster_labels = km_all.labels_ + + cluster_0_mask = cluster_labels == 0 + cluster_1_mask = cluster_labels == 1 + + samples = self.dataset.cpu().numpy() + + cluster_0_data = samples[cluster_0_mask] + cluster_1_data = samples[cluster_1_mask] + + self.metric_samples_dataloaders = [ + DataLoader( + torch.tensor(cluster_1_data, dtype=torch.float32), + batch_size=cluster_1_data.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + torch.tensor(cluster_0_data, dtype=torch.float32), + batch_size=cluster_0_data.shape[0], + shuffle=False, + drop_last=False, + ), + ] + + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_manifold_proj(self, points): + """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting""" + return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset) + + @staticmethod + def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3): + """ + Apply local smoothing based on k-nearest neighbors in the full dataset + This replaces the plane projection for 2D manifold regularization + """ + points_np = x.detach().cpu().numpy() + _, idx = tree.query(points_np, k=k) + nearest_pts = dataset[idx] # Shape: (batch_size, k, 2) + + # Compute weighted average of neighbors + dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + weights = weights / weights.sum(dim=1, keepdim=True) + + # Weighted average of neighbors + smoothed = (weights * nearest_pts).sum(dim=1) + + # Blend original point with smoothed version + alpha = 0.3 # How much smoothing to apply + return (1 - alpha) * x + alpha * smoothed + + def get_timepoint_data(self): + """Return data organized by timepoints for visualization""" + return { + 't0': self.coords_t0, + 't1': self.coords_t1, + 't2_1': self.coords_t2_1, + 't2_2': self.coords_t2_2, + 'time_labels': self.time_labels + } + + + +class SingleBranchCellDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.data_path = args.data_path + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.k = 20 + self.n_samples = 1429 + self.num_timesteps = 3 # t=0, t=1, t=2 + self.split_ratios = args.split_ratios + self.metric_clusters = 3 + self.args = args + self._prepare_data() + + + def _prepare_data(self): + print("Preparing cell data in BranchedCellDataModule") + + df = pd.read_csv(self.data_path) + + # Build dictionary of coordinates by time + coords_by_t = { + t: df[df["samples"] == t][["x1","x2"]].values + for t in sorted(df["samples"].unique()) + } + n0 = coords_by_t[0].shape[0] # Number of T=0 points + self.n_samples = n0 # Update n_samples to match actual data if changes + + x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) # T=0 coordinates index + x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32) + x1 = torch.tensor(coords_by_t[2], dtype=torch.float32) # Branch index + + # Store for get_timepoint_data() + self.coords_t0 = x0 + self.coords_t1 = x_inter + self.coords_t2 = x1 + self.time_labels = np.concatenate([ + np.zeros(len(x0)), + np.ones(len(x_inter)), + np.ones(len(x1)) * 2, + ]) + + split_index = int(n0 * self.split_ratios[0]) + + if n0 - split_index < self.batch_size: + split_index = n0 - self.batch_size + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + train_x1 = x1[:split_index] + val_x1 = x1[split_index:] + + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=0.5) + + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=0.5) + + if self.n_samples - split_index < self.batch_size: + split_index = self.n_samples - self.batch_size + + self.train_dataloaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + self.val_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), + "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())]) + self.dataset = torch.tensor(all_data, dtype=torch.float32) + self.tree = cKDTree(all_data) + + # if whitening is enabled, need to apply this to the full dataset + if self.whiten: + self.scaler = StandardScaler() + self.dataset = torch.tensor( + self.scaler.fit_transform(all_data), dtype=torch.float32 + ) + + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), + } + + # Metric Dataloader + # K-means clustering of ALL points into 2 groups + km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy()) + cluster_labels = km_all.labels_ + + cluster_0_mask = cluster_labels == 0 + cluster_1_mask = cluster_labels == 1 + + samples = self.dataset.cpu().numpy() + + cluster_0_data = samples[cluster_0_mask] + cluster_1_data = samples[cluster_1_mask] + + self.metric_samples_dataloaders = [ + DataLoader( + torch.tensor(cluster_1_data, dtype=torch.float32), + batch_size=cluster_1_data.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + torch.tensor(cluster_0_data, dtype=torch.float32), + batch_size=cluster_0_data.shape[0], + shuffle=False, + drop_last=False, + ), + ] + + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_manifold_proj(self, points): + """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting""" + return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset) + + @staticmethod + def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3): + """ + Apply local smoothing based on k-nearest neighbors in the full dataset + This replaces the plane projection for 2D manifold regularization + """ + points_np = x.detach().cpu().numpy() + _, idx = tree.query(points_np, k=k) + nearest_pts = dataset[idx] # Shape: (batch_size, k, 2) + + # Compute weighted average of neighbors + dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + weights = weights / weights.sum(dim=1, keepdim=True) + + # Weighted average of neighbors + smoothed = (weights * nearest_pts).sum(dim=1) + + # Blend original point with smoothed version + alpha = 0.3 # How much smoothing to apply + return (1 - alpha) * x + alpha * smoothed + + def get_timepoint_data(self): + """Return data organized by timepoints for visualization""" + return { + 't0': self.coords_t0, + 't1': self.coords_t1, + 't2': self.coords_t2, + 'time_labels': self.time_labels + } + +"""def get_datamodule(): + datamodule = WeightedBranchedCellDataModule(args) + datamodule.setup(stage="fit") + return datamodule""" \ No newline at end of file diff --git a/dataloaders/three_branch_data.py b/dataloaders/three_branch_data.py new file mode 100755 index 0000000000000000000000000000000000000000..1240c7caa67168b8f6da90a7946bfda3601ea3e7 --- /dev/null +++ b/dataloaders/three_branch_data.py @@ -0,0 +1,306 @@ +import torch +import sys +from sklearn.preprocessing import StandardScaler +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from lightning.pytorch.utilities.combined_loader import CombinedLoader +import pandas as pd +import numpy as np +from functools import partial +from scipy.spatial import cKDTree +from sklearn.cluster import KMeans +from torch.utils.data import TensorDataset + +class ThreeBranchTahoeDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.split_ratios = args.split_ratios + self.num_timesteps = 2 + self.data_path = f"{args.working_dir}/data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv" + self.args = args + + self._prepare_data() + + def _prepare_data(self): + df = pd.read_csv(self.data_path, comment='#') + df = df.iloc[:, 1:] + df = df.replace('', np.nan) + pc_cols = df.columns[:50] + for col in pc_cols: + df[col] = pd.to_numeric(df[col], errors='coerce') + leiden_dmso_col = 'leiden_DMSO_TF_0.0uM' + leiden_clonidine_col = 'leiden_Trametinib_5.0uM' + + dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column + clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column + + dmso_data = df[dmso_mask].copy() + clonidine_data = df[clonidine_mask].copy() + + # Updated to include all three clusters: 0, 4, and 6 + top_clonidine_clusters = ['1.0', '3.0', '5.0'] + + x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]] + x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]] + x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]] + + x1_1_coords = x1_1_data[pc_cols].values + x1_2_coords = x1_2_data[pc_cols].values + x1_3_coords = x1_3_data[pc_cols].values + + x1_1_coords = x1_1_coords.astype(float) + x1_2_coords = x1_2_coords.astype(float) + x1_3_coords = x1_3_coords.astype(float) + + # Target size is now the minimum across all three endpoint clusters + target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords)) + + # Helper function to select points closest to centroid + def select_closest_to_centroid(coords, target_size): + if len(coords) <= target_size: + return coords + + # Calculate centroid + centroid = np.mean(coords, axis=0) + + # Calculate distances to centroid + distances = np.linalg.norm(coords - centroid, axis=1) + + # Get indices of closest points + closest_indices = np.argsort(distances)[:target_size] + + return coords[closest_indices] + + # Sample all endpoint clusters to target size using centroid-based selection + x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size) + x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size) + x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size) + + dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts() + + # DMSO (unchanged) + largest_dmso_cluster = dmso_cluster_counts.index[0] + dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster] + + dmso_coords = dmso_cluster_data[pc_cols].values + + # Random sampling from largest DMSO cluster to match target size + # For DMSO, we'll also use centroid-based selection for consistency + if len(dmso_coords) >= target_size: + x0_coords = select_closest_to_centroid(dmso_coords, target_size) + else: + # If largest cluster is smaller than target, use all of it and pad with other DMSO cells + remaining_needed = target_size - len(dmso_coords) + other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster] + other_dmso_coords = other_dmso_data[pc_cols].values + + if len(other_dmso_coords) >= remaining_needed: + # Select closest to centroid from other DMSO cells + other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed) + x0_coords = np.vstack([dmso_coords, other_selected]) + else: + # Use all available DMSO cells and reduce target size + all_dmso_coords = dmso_data[pc_cols].values + target_size = min(target_size, len(all_dmso_coords)) + x0_coords = select_closest_to_centroid(all_dmso_coords, target_size) + + # Re-select endpoint clusters with updated target size + x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size) + x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size) + x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size) + + + self.n_samples = target_size + + # for plotting + self.coords_t0 = torch.tensor(x0_coords, dtype=torch.float32) + self.coords_t1_1 = torch.tensor(x1_1_coords, dtype=torch.float32) + self.coords_t1_2 = torch.tensor(x1_2_coords, dtype=torch.float32) + self.coords_t1_3 = torch.tensor(x1_3_coords, dtype=torch.float32) + + self.time_labels = np.concatenate([ + np.zeros(len(self.coords_t0)), # t=0 + np.ones(len(self.coords_t1_1)), # t=1 + np.ones(len(self.coords_t1_2)), # t=1 + np.ones(len(self.coords_t1_3)), # t=1 + ]) + + x0 = torch.tensor(x0_coords, dtype=torch.float32) + x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32) + x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32) + x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32) + + split_index = int(target_size * self.split_ratios[0]) + + if target_size - split_index < self.batch_size: + split_index = target_size - self.batch_size + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + train_x1_1 = x1_1[:split_index] + val_x1_1 = x1_1[split_index:] + train_x1_2 = x1_2[:split_index] + val_x1_2 = x1_2[split_index:] + train_x1_3 = x1_3[:split_index] + val_x1_3 = x1_3[split_index:] + + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.603) + train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.255) + train_x1_3_weights = torch.full((train_x1_3.shape[0], 1), fill_value=0.142) + + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.603) + val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.255) + val_x1_3_weights = torch.full((val_x1_3.shape[0], 1), fill_value=0.142) + + # Updated train dataloaders to include x1_3 + self.train_dataloaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_3": DataLoader(TensorDataset(train_x1_3, train_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + # Updated val dataloaders to include x1_3 + self.val_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), + "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1_3": DataLoader(TensorDataset(val_x1_3, val_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + all_coords = df[pc_cols].dropna().values.astype(float) + self.dataset = torch.tensor(all_coords, dtype=torch.float32) + self.tree = cKDTree(all_coords) + + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), + } + + # Updated metric samples - now using 4 clusters instead of 3 + #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy()) + km_all = KMeans(n_clusters=4, random_state=0).fit(self.dataset[:, :3].numpy()) + + cluster_labels = km_all.labels_ + + cluster_0_mask = cluster_labels == 0 + cluster_1_mask = cluster_labels == 1 + cluster_2_mask = cluster_labels == 2 + cluster_3_mask = cluster_labels == 3 + + samples = self.dataset.cpu().numpy() + + cluster_0_data = samples[cluster_0_mask] + cluster_1_data = samples[cluster_1_mask] + cluster_2_data = samples[cluster_2_mask] + cluster_3_data = samples[cluster_3_mask] + + self.metric_samples_dataloaders = [ + DataLoader( + torch.tensor(cluster_1_data, dtype=torch.float32), + batch_size=cluster_1_data.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + torch.tensor(cluster_3_data, dtype=torch.float32), + batch_size=cluster_3_data.shape[0], + shuffle=False, + drop_last=False, + ), + + + DataLoader( + torch.tensor(cluster_2_data, dtype=torch.float32), + batch_size=cluster_2_data.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + torch.tensor(cluster_0_data, dtype=torch.float32), + batch_size=cluster_0_data.shape[0], + shuffle=False, + drop_last=False, + ), + ] + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_manifold_proj(self, points): + """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting""" + return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset) + + @staticmethod + def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3): + """ + Apply local smoothing based on k-nearest neighbors in the full dataset + This replaces the plane projection for 2D manifold regularization + """ + points_np = x.detach().cpu().numpy() + _, idx = tree.query(points_np, k=k) + nearest_pts = dataset[idx] # Shape: (batch_size, k, 2) + + # Compute weighted average of neighbors + dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + weights = weights / weights.sum(dim=1, keepdim=True) + + # Weighted average of neighbors + smoothed = (weights * nearest_pts).sum(dim=1) + + # Blend original point with smoothed version + alpha = 0.3 # How much smoothing to apply + return (1 - alpha) * x + alpha * smoothed + + def get_timepoint_data(self): + """Return data organized by timepoints for visualization""" + return { + 't0': self.coords_t0, + 't1_1': self.coords_t1_1, + 't1_2': self.coords_t1_2, + 't1_3': self.coords_t1_3, + 'time_labels': self.time_labels + } + +def get_datamodule(): + from plot.parsers_tahoe import parse_args + args = parse_args() + datamodule = ThreeBranchTahoeDataModule(args) + datamodule.setup(stage="fit") + return datamodule \ No newline at end of file diff --git a/dataloaders/trametinib_single.py b/dataloaders/trametinib_single.py new file mode 100755 index 0000000000000000000000000000000000000000..4c5d03d3f9016b62d6aea9a0169c04136779c89b --- /dev/null +++ b/dataloaders/trametinib_single.py @@ -0,0 +1,268 @@ +import torch +import sys +from sklearn.preprocessing import StandardScaler +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from lightning.pytorch.utilities.combined_loader import CombinedLoader +import pandas as pd +import numpy as np +from functools import partial +from scipy.spatial import cKDTree +from sklearn.cluster import KMeans +from torch.utils.data import TensorDataset + + +class TrametinibSingleBranchDataModule(pl.LightningDataModule): + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.split_ratios = args.split_ratios + self.num_timesteps = 2 + self.data_path = args.data_path + self.args = args + + self._prepare_data() + + def _prepare_data(self): + df = pd.read_csv(self.data_path, comment='#') + df = df.iloc[:, 1:] + df = df.replace('', np.nan) + pc_cols = df.columns[:50] + for col in pc_cols: + df[col] = pd.to_numeric(df[col], errors='coerce') + leiden_dmso_col = 'leiden_DMSO_TF_0.0uM' + leiden_clonidine_col = 'leiden_Trametinib_5.0uM' + + dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column + clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column + + dmso_data = df[dmso_mask].copy() + clonidine_data = df[clonidine_mask].copy() + + # Updated to include all three clusters: 0, 4, and 6 + top_clonidine_clusters = ['1.0', '3.0', '5.0'] + + x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]] + x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]] + x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]] + + x1_1_coords = x1_1_data[pc_cols].values + x1_2_coords = x1_2_data[pc_cols].values + x1_3_coords = x1_3_data[pc_cols].values + + x1_1_coords = x1_1_coords.astype(float) + x1_2_coords = x1_2_coords.astype(float) + x1_3_coords = x1_3_coords.astype(float) + + # Target size is now the minimum across all three endpoint clusters + target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords)) + + # Helper function to select points closest to centroid + def select_closest_to_centroid(coords, target_size): + if len(coords) <= target_size: + return coords + + # Calculate centroid + centroid = np.mean(coords, axis=0) + + # Calculate distances to centroid + distances = np.linalg.norm(coords - centroid, axis=1) + + # Get indices of closest points + closest_indices = np.argsort(distances)[:target_size] + + return coords[closest_indices] + + # Sample all endpoint clusters to target size using centroid-based selection + x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size) + x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size) + x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size) + + dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts() + + # DMSO (unchanged) + largest_dmso_cluster = dmso_cluster_counts.index[0] + dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster] + + dmso_coords = dmso_cluster_data[pc_cols].values + + # Random sampling from largest DMSO cluster to match target size + # For DMSO, we'll also use centroid-based selection for consistency + if len(dmso_coords) >= target_size: + x0_coords = select_closest_to_centroid(dmso_coords, target_size) + else: + # If largest cluster is smaller than target, use all of it and pad with other DMSO cells + remaining_needed = target_size - len(dmso_coords) + other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster] + other_dmso_coords = other_dmso_data[pc_cols].values + + if len(other_dmso_coords) >= remaining_needed: + # Select closest to centroid from other DMSO cells + other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed) + x0_coords = np.vstack([dmso_coords, other_selected]) + else: + # Use all available DMSO cells and reduce target size + all_dmso_coords = dmso_data[pc_cols].values + target_size = min(target_size, len(all_dmso_coords)) + x0_coords = select_closest_to_centroid(all_dmso_coords, target_size) + + # Re-select endpoint clusters with updated target size + x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size) + x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size) + x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size) + + self.n_samples = target_size + + # for plotting + + + x0 = torch.tensor(x0_coords, dtype=torch.float32) + x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32) + x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32) + x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32) + x1 = torch.cat([x1_1, x1_2, x1_3], dim=0) + + self.coords_t0 = x0 + self.coords_t1 = x1 + + self.time_labels = np.concatenate([ + np.zeros(len(self.coords_t0)), # t=0 + np.ones(len(self.coords_t1)), # t=1 + ]) + + split_index = int(target_size * self.split_ratios[0]) + + if target_size - split_index < self.batch_size: + split_index = target_size - self.batch_size + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + train_x1 = x1_1[:split_index] + val_x1 = x1_1[split_index:] + + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0) + + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0) + + # Updated train dataloaders to include x1_3 + self.train_dataloaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + # Updated val dataloaders to include x1_3 + self.val_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), + "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + + all_coords = df[pc_cols].dropna().values.astype(float) + self.dataset = torch.tensor(all_coords, dtype=torch.float32) + self.tree = cKDTree(all_coords) + + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), + } + + # Updated metric samples - now using 4 clusters instead of 3 + #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy()) + km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset[:, :3].numpy()) + + cluster_labels = km_all.labels_ + + cluster_0_mask = cluster_labels == 0 + cluster_1_mask = cluster_labels == 1 + + samples = self.dataset.cpu().numpy() + + cluster_0_data = samples[cluster_0_mask] + cluster_1_data = samples[cluster_1_mask] + + self.metric_samples_dataloaders = [ + DataLoader( + torch.tensor(cluster_1_data, dtype=torch.float32), + batch_size=cluster_1_data.shape[0], + shuffle=False, + drop_last=False, + ), + DataLoader( + torch.tensor(cluster_0_data, dtype=torch.float32), + batch_size=cluster_0_data.shape[0], + shuffle=False, + drop_last=False, + ), + ] + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader( + self.metric_samples_dataloaders, mode="min_size" + ), + } + + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_manifold_proj(self, points): + """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting""" + return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset) + + @staticmethod + def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3): + """ + Apply local smoothing based on k-nearest neighbors in the full dataset + This replaces the plane projection for 2D manifold regularization + """ + points_np = x.detach().cpu().numpy() + _, idx = tree.query(points_np, k=k) + nearest_pts = dataset[idx] # Shape: (batch_size, k, 2) + + # Compute weighted average of neighbors + dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + weights = weights / weights.sum(dim=1, keepdim=True) + + # Weighted average of neighbors + smoothed = (weights * nearest_pts).sum(dim=1) + + # Blend original point with smoothed version + alpha = 0.3 # How much smoothing to apply + return (1 - alpha) * x + alpha * smoothed + + def get_timepoint_data(self): + """Return data organized by timepoints for visualization""" + return { + 't0': self.coords_t0, + 't1': self.coords_t1, + 'time_labels': self.time_labels + } diff --git a/dataloaders/veres_leiden_data.py b/dataloaders/veres_leiden_data.py new file mode 100755 index 0000000000000000000000000000000000000000..6e41f2514d0c77594a966675d2b140ec371e4499 --- /dev/null +++ b/dataloaders/veres_leiden_data.py @@ -0,0 +1,317 @@ +import torch +import sys +from sklearn.preprocessing import StandardScaler +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from lightning.pytorch.utilities.combined_loader import CombinedLoader +import numpy as np +from scipy.spatial import cKDTree +import math +from functools import partial +import matplotlib.pyplot as plt +import pandas as pd +from torch.utils.data import TensorDataset +from sklearn.neighbors import kneighbors_graph +import igraph as ig +from leidenalg import find_partition, ModularityVertexPartition + +class WeightedBranchedVeresDataModule(pl.LightningDataModule): + + def __init__(self, args): + super().__init__() + self.save_hyperparameters() + + self.data_path = args.data_path + self.batch_size = args.batch_size + self.max_dim = args.dim + self.whiten = args.whiten + self.k = 20 + self.num_timesteps = 8 + # initial placeholder, will be set by clustering result + self.num_branches = args.branches if hasattr(args, 'branches') else None + self.split_ratios = args.split_ratios + self.metric_clusters = args.metric_clusters + self.discard_small = args.discard if hasattr(args, 'discard') else False + self.args = args + self._prepare_data() + + def _prepare_data(self): + print("Preparing Veres cell data with Leiden clustering in WeightedBranchedVeresLeidenDataModule") + df = pd.read_csv(self.data_path) + + # Build dictionary of coordinates by time + coords_by_t = { + t: df[df["samples"] == t].iloc[:, 1:].values # Skip 'samples' column + for t in sorted(df["samples"].unique()) + } + + n0 = coords_by_t[0].shape[0] + self.n_samples = n0 + + print("Timepoint distribution:") + for t in sorted(coords_by_t.keys()): + print(f" t={t}: {coords_by_t[t].shape[0]} points") + + # Leiden clustering on final timepoint + final_t = max(coords_by_t.keys()) + coords_final = coords_by_t[final_t] + k = 20 + knn_graph = kneighbors_graph(coords_final, k, mode='connectivity', include_self=False) + sources, targets = knn_graph.nonzero() + edgelist = list(zip(sources.tolist(), targets.tolist())) + graph = ig.Graph(edgelist, directed=False) + partition = find_partition(graph, ModularityVertexPartition) + leiden_labels = np.array(partition.membership) + n_leiden = len(np.unique(leiden_labels)) + print(f"Leiden found {n_leiden} clusters at t={final_t}") + + df_final = df[df["samples"] == final_t].copy() + df_final["branch"] = leiden_labels + + cluster_counts = df_final["branch"].value_counts().sort_index() + print(f"Branch distribution at t={final_t} (pre-merge):") + print(cluster_counts) + + # Merge small clusters to nearest large cluster (by centroid) + min_cells = 100 # threshold; adjust if needed + cluster_data_dict = {} + cluster_sizes = [] + for b in range(n_leiden): + branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values + cluster_data_dict[b] = branch_data + cluster_sizes.append(branch_data.shape[0]) + + large_clusters = [b for b, size in enumerate(cluster_sizes) if size >= min_cells] + small_clusters = [b for b, size in enumerate(cluster_sizes) if size < min_cells] + + # If no large cluster exists (all small), treat all clusters as large + if len(large_clusters) == 0: + large_clusters = list(range(n_leiden)) + small_clusters = [] + + if self.discard_small: + # Discard small clusters instead of merging + print(f"Discarding {len(small_clusters)} small clusters (< {min_cells} cells)") + # Keep only cells from large clusters + mask = np.isin(leiden_labels, large_clusters) + df_final = df_final[mask].copy() + merged_labels = leiden_labels[mask] + + # Remap to contiguous ids + new_ids = np.unique(merged_labels) + id_map = {old: new for new, old in enumerate(new_ids)} + merged_labels = np.array([id_map[x] for x in merged_labels]) + n_merged = len(np.unique(merged_labels)) + + df_final["branch"] = merged_labels + print(f"Kept {n_merged} large clusters") + else: + centroids = {b: np.mean(cluster_data_dict[b], axis=0) for b in range(n_leiden) if cluster_data_dict[b].shape[0] > 0} + + merged_labels = leiden_labels.copy() + for b in small_clusters: + if cluster_data_dict[b].shape[0] == 0: + continue + # find nearest large cluster + dists = [np.linalg.norm(centroids[b] - centroids[bl]) for bl in large_clusters] + nearest_large = large_clusters[int(np.argmin(dists))] + merged_labels[leiden_labels == b] = nearest_large + + # remap to contiguous ids + new_ids = np.unique(merged_labels) + id_map = {old: new for new, old in enumerate(new_ids)} + merged_labels = np.array([id_map[x] for x in merged_labels]) + n_merged = len(np.unique(merged_labels)) + + df_final["branch"] = merged_labels + print(f"Merged into {n_merged} clusters") + + cluster_counts_merged = df_final["branch"].value_counts().sort_index() + print(f"Branch distribution at t={final_t} (post-merge):") + print(cluster_counts_merged) + + endpoints = {} + cluster_sizes = [] + for b in range(n_merged): + branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values + cluster_sizes.append(branch_data.shape[0]) + replace = branch_data.shape[0] < n0 + sampled_indices = np.random.choice(branch_data.shape[0], size=n0, replace=replace) + endpoints[b] = branch_data[sampled_indices] + total_t_final = sum(cluster_sizes) + + x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) + self.coords_t0 = x0 + # intermediate timepoints + self.coords_intermediate = {t: torch.tensor(coords_by_t[t], dtype=torch.float32) + for t in coords_by_t.keys() if t != 0 and t != final_t} + + self.branch_endpoints = {b: torch.tensor(endpoints[b], dtype=torch.float32) for b in range(n_merged)} + self.num_branches = n_merged + + # time labels (for visualization) + time_labels_list = [np.zeros(len(self.coords_t0))] + for t in sorted(self.coords_intermediate.keys()): + time_labels_list.append(np.ones(len(self.coords_intermediate[t])) * t) + for b in range(self.num_branches): + time_labels_list.append(np.ones(len(self.branch_endpoints[b])) * final_t) + self.time_labels = np.concatenate(time_labels_list) + + # splits + split_index = int(n0 * self.split_ratios[0]) + if n0 - split_index < self.batch_size: + split_index = n0 - self.batch_size + + train_x0 = x0[:split_index] + val_x0 = x0[split_index:] + self.val_x0 = val_x0 + + train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0) + val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0) + + # branch weights proportional to cluster sizes + branch_weights = [size / total_t_final for size in cluster_sizes] + + # Split intermediate timepoints for sequential training support + train_intermediate = {} + val_intermediate = {} + self.train_coords_intermediate = {} # Store training-only intermediate data for MMD + for t in sorted(self.coords_intermediate.keys()): + coords_t = self.coords_intermediate[t] + train_coords_t = coords_t[:split_index] + val_coords_t = coords_t[split_index:] + train_weights_t = torch.full((train_coords_t.shape[0], 1), fill_value=1.0) + val_weights_t = torch.full((val_coords_t.shape[0], 1), fill_value=1.0) + train_intermediate[f"x{t}"] = (train_coords_t, train_weights_t) + val_intermediate[f"x{t}"] = (val_coords_t, val_weights_t) + self.train_coords_intermediate[t] = train_coords_t # Store training data by int key + + train_loaders = { + "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True), + } + val_loaders = { + "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True), + } + + # Add all intermediate timepoints to loaders + for t_key in sorted(train_intermediate.keys()): + train_coords_t, train_weights_t = train_intermediate[t_key] + val_coords_t, val_weights_t = val_intermediate[t_key] + train_loaders[t_key] = DataLoader( + TensorDataset(train_coords_t, train_weights_t), + batch_size=self.batch_size, + shuffle=True, + drop_last=True + ) + val_loaders[t_key] = DataLoader( + TensorDataset(val_coords_t, val_weights_t), + batch_size=self.batch_size, + shuffle=False, + drop_last=True + ) + + for b in range(self.num_branches): + # Calculate split based on this branch's size, not t=0 size + branch_size = self.branch_endpoints[b].shape[0] + branch_split_index = int(branch_size * self.split_ratios[0]) + if branch_size - branch_split_index < self.batch_size: + branch_split_index = max(0, branch_size - self.batch_size) + + train_branch = self.branch_endpoints[b][:branch_split_index] + val_branch = self.branch_endpoints[b][branch_split_index:] + train_branch_weights = torch.full((train_branch.shape[0], 1), fill_value=branch_weights[b]) + val_branch_weights = torch.full((val_branch.shape[0], 1), fill_value=branch_weights[b]) + train_loaders[f"x1_{b+1}"] = DataLoader( + TensorDataset(train_branch, train_branch_weights), + batch_size=self.batch_size, + shuffle=True, + drop_last=True + ) + val_loaders[f"x1_{b+1}"] = DataLoader( + TensorDataset(val_branch, val_branch_weights), + batch_size=self.batch_size, + shuffle=True, + drop_last=True + ) + + self.train_dataloaders = train_loaders + self.val_dataloaders = val_loaders + + # full dataset + all_data_list = [coords_by_t[t] for t in sorted(coords_by_t.keys())] + all_data = np.vstack(all_data_list) + self.dataset = torch.tensor(all_data, dtype=torch.float32) + self.tree = cKDTree(all_data) + + self.test_dataloaders = { + "x0": DataLoader(TensorDataset(self.val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False), + "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False), + } + + # Metric dataloaders: t0 vs (t1..t_final + endpoints) + cluster_0_data = self.coords_t0.cpu().numpy() + cluster_1_list = [self.coords_intermediate[t].cpu().numpy() for t in sorted(self.coords_intermediate.keys())] + cluster_1_list.extend([self.branch_endpoints[b].cpu().numpy() for b in range(self.num_branches)]) + cluster_1_data = np.vstack(cluster_1_list) + + self.metric_samples_dataloaders = [ + DataLoader(torch.tensor(cluster_0_data, dtype=torch.float32), batch_size=cluster_0_data.shape[0], shuffle=False, drop_last=False), + DataLoader(torch.tensor(cluster_1_data, dtype=torch.float32), batch_size=cluster_1_data.shape[0], shuffle=False, drop_last=False), + ] + + def train_dataloader(self): + combined_loaders = { + "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def val_dataloader(self): + combined_loaders = { + "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def test_dataloader(self): + combined_loaders = { + "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"), + "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"), + } + return CombinedLoader(combined_loaders, mode="max_size_cycle") + + def get_manifold_proj(self, points): + return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset) + + @staticmethod + def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3): + points_np = x.detach().cpu().numpy() + _, idx = tree.query(points_np, k=k) + nearest_pts = dataset[idx] + dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True) + weights = torch.exp(-dists / temp) + weights = weights / weights.sum(dim=1, keepdim=True) + smoothed = (weights * nearest_pts).sum(dim=1) + alpha = 0.3 + return (1 - alpha) * x + alpha * smoothed + + def get_timepoint_data(self): + result = { + 't0': self.coords_t0, + 'time_labels': self.time_labels + } + # intermediate timepoints + for t in sorted(self.coords_intermediate.keys()): + result[f't{t}'] = self.coords_intermediate[t] + final_t = max([0] + list(self.coords_intermediate.keys())) + 1 + for b in range(self.num_branches): + result[f't{final_t}_{b}'] = self.branch_endpoints[b] + return result + + def get_train_intermediate_data(self): + if hasattr(self, 'train_coords_intermediate'): + return self.train_coords_intermediate + else: + # Fallback to full intermediate data if train split not available + print("Warning: train_coords_intermediate not found, returning full intermediate data.") + return self.coords_intermediate diff --git a/environment.yml b/environment.yml new file mode 100755 index 0000000000000000000000000000000000000000..ad1bd02156a1c8eb24d4817830b9e401411de2a7 --- /dev/null +++ b/environment.yml @@ -0,0 +1,41 @@ +name: branchsbm +channels: + - conda-forge + - pytorch + - defaults +dependencies: + - conda-forge::python=3.10 + - conda-forge::openssl + - ca-certificates + - certifi + - pytorch::pytorch + - matplotlib + - pandas + - seaborn + - torchmetrics + - numpy>=1.26.0,<2.0.0 + - scikit-learn + - pyyaml + - jupyter + - ipykernel + - notebook + - tqdm + - pytorch-lightning>=2.0.0 + - lightning>=2.0.0 + - python-igraph + - leidenalg + - pip + - pip: + - scipy==1.13.1 + - wandb==0.22.1 + - torchcfm==1.0.7 + - torchdyn==1.0.6 + - torchdiffeq + - pot + - hydra-core + - omegaconf + - laspy + - umap-learn + - scanpy + - lpips + - geomloss \ No newline at end of file diff --git a/parsers.py b/parsers.py new file mode 100755 index 0000000000000000000000000000000000000000..2e1e4bc113d16ac1abe7acccdea8869a5af98b22 --- /dev/null +++ b/parsers.py @@ -0,0 +1,502 @@ +import argparse + +def parse_args(): + parser = argparse.ArgumentParser(description="Train BranchSBM") + + parser.add_argument("--seed", default=2, type=int) + + parser.add_argument( + "--config_path", type=str, + default='', + help="Path to config file" + ) + ####### ITERATES IN THE CODE ####### + parser.add_argument( + "--seeds", + nargs="+", + type=int, + default=[42, 43, 44, 45, 46], + help="Random seeds to iterate over", + ) + parser.add_argument( + "--t_exclude", + nargs="+", + type=int, + default=None, + help="Time points to exclude (iterating over)", + ) + #################################### + + parser.add_argument( + "--working_dir", + type=str, + default="path/to/your/home/BranchSBM", + help="Working directory", + ) + parser.add_argument( + "--resume_flow_model_ckpt", + type=str, + default=None, + help="Path to the flow model to resume training", + ) + parser.add_argument( + "--resume_growth_model_ckpt", + type=str, + default=None, + help="Path to the flow model to resume training", + ) + parser.add_argument( + "--load_geopath_model_ckpt", + type=str, + default=None, + help="Path to the geopath model to resume training", + ) + parser.add_argument( + "--sequential", + action=argparse.BooleanOptionalAction, + default=False, + help="Use sequential training for multi-timepoint data", + ) + parser.add_argument( + "--discard", + action=argparse.BooleanOptionalAction, + default=False, + help="Discard small clusters instead of merging them in Leiden clustering", + ) + parser.add_argument( + "--pseudo", + action=argparse.BooleanOptionalAction, + default=False, + help="Use pseudotime-based clustering for Weinreb data instead of Leiden on t=2", + ) + parser.add_argument( + "--branches", + type=int, + default=2, + help="Number of branches", + ) + parser.add_argument( + "--metric_clusters", + type=int, + default=3, + help="Number of metric clusters", + ) + parser.add_argument( + "--resolution", + type=float, + default=1.0, + help="Resolution parameter for Leiden clustering", + ) + + ######### DATASETS ################# + parser = datasets_parser(parser) + #################################### + + ######### IMAGE DATASETS ########### + parser = image_datasets_parser(parser) + #################################### + + ######### METRICS ################## + parser = metric_parser(parser) + #################################### + + ######### General Training ######### + parser = general_training_parser(parser) + #################################### + + ######### Training GeoPath Network #### + parser = geopath_network_parser(parser) + #################################### + + ######### Training Flow Network #### + parser = flow_network_parser(parser) + #################################### + + parser = growth_network_parser(parser) + + return parser.parse_args() + + +def datasets_parser(parser): + parser.add_argument("--dim", type=int, default=3, help="Dimension of data") + + parser.add_argument( + "--data_type", + type=str, + default="lidar", + help="Type of data, now wither scrna or one of toys", + ) + parser.add_argument( + "--data_path", + type=str, + default="", + help="lidar data path", + ) + parser.add_argument( + "--data_name", + type=str, + default="lidar", + help="Path to the dataset", + ) + parser.add_argument( + "--whiten", + action=argparse.BooleanOptionalAction, + default=True, + help="Whiten the data", + ) + parser.add_argument( + "--min_cells", + type=int, + default=500, + help="Minimum cells per cluster for Leiden clustering", + ) + parser.add_argument( + "--k", + type=int, + default=20, + help="Number of neighbors for KNN graph in Leiden clustering", + ) + parser.add_argument( + "--pseudotime_threshold", + type=float, + default=0.6, + help="Pseudotime threshold for terminal cells (only used when --pseudo is True)", + ) + parser.add_argument( + "--terminal_neighbors", + type=int, + default=20, + help="Number of neighbors for terminal cell clustering (only used when --pseudo is True)", + ) + parser.add_argument( + "--terminal_resolution", + type=float, + default=0.2, + help="Resolution for terminal cell Leiden clustering (only used when --pseudo is True)", + ) + parser.add_argument( + "--n_dcs", + type=int, + default=10, + help="Number of diffusion components for DPT (only used when --pseudo is True)", + ) + parser.add_argument( + "--initial_neighbors", + type=int, + default=30, + help="Number of neighbors for initial kNN graph (only used when --pseudo is True)", + ) + parser.add_argument( + "--initial_resolution", + type=float, + default=1.0, + help="Resolution for initial Leiden clustering (only used when --pseudo is True)", + ) + return parser + + +def image_datasets_parser(parser): + parser.add_argument( + "--image_size", + type=int, + default=128, + help="Size of the image", + ) + parser.add_argument( + "--x0_label", + type=str, + default="dog", + help="Label for x0", + ) + parser.add_argument( + "--x1_label", + type=str, + default="cat", + help="Label for x1", + ) + return parser + + +def metric_parser(parser): + parser.add_argument( + "--branchsbm", + action=argparse.BooleanOptionalAction, + default=True, + help="If branched SBM", + ) + parser.add_argument( + "--n_centers", + type=int, + default=100, + help="Number of centers for RBF network", + ) + parser.add_argument( + "--kappa", + type=float, + default=1.0, + help="Kappa parameter for RBF network", + ) + parser.add_argument( + "--rho", + type=float, + default=0.001, + help="Rho parameter in Riemanian Velocity Calculation", + ) + parser.add_argument( + "--velocity_metric", + type=str, + default="rbf", + help="Metric for velocity calculation", + ) + parser.add_argument( + "--gammas", + nargs="+", + type=float, + default=[0.2, 0.2], + help="Gamma parameter in Riemanian Velocity Calculation", + ) + + parser.add_argument( + "--metric_epochs", + type=int, + default=100, + help="Number of epochs for metric learning", + ) + parser.add_argument( + "--metric_patience", + type=int, + default=20, + help="Patience for metric learning", + ) + parser.add_argument( + "--metric_lr", + type=float, + default=1e-2, + help="Learning rate for metric learning", + ) + parser.add_argument( + "--alpha_metric", + type=float, + default=1.0, + help="Alpha parameter for metric learning", + ) + + return parser + + +def general_training_parser(parser): + parser.add_argument( + "--batch_size", type=int, default=128, help="Batch size for training" + ) + parser.add_argument( + "--optimal_transport_method", + type=str, + default="exact", + help="Use optimal transport in CFM training", + ) + parser.add_argument( + "--ema_decay", + type=float, + default=None, + help="Decay for EMA", + ) + parser.add_argument( + "--split_ratios", + nargs=2, + type=float, + default=[0.9, 0.1], + help="Split ratios for training/validation data in CFM training", + ) + parser.add_argument("--epochs", type=int, default=100, help="Number of epochs") + parser.add_argument( + "--accelerator", type=str, default="gpu", help="Training accelerator" + ) + parser.add_argument( + "--run_name", type=str, default=None, help="Name for the wandb run" + ) + parser.add_argument( + "--sim_num_steps", + type=int, + default=1000, + help="Number of steps in simulation", + ) + return parser + + +def geopath_network_parser(parser): + parser.add_argument( + "--manifold", + action=argparse.BooleanOptionalAction, + default=True, + help="If use data manifold metric", + ) + parser.add_argument( + "--patience_geopath", + type=int, + default=50, + help="Patience for training geopath model", + ) + parser.add_argument( + "--hidden_dims_geopath", + nargs="+", + type=int, + default=[64, 64, 64], + help="Dimensions of hidden layers for GeoPath model training", + ) + parser.add_argument( + "--time_geopath", + action=argparse.BooleanOptionalAction, + default=False, + help="Use time in GeoPath model", + ) + parser.add_argument( + "--activation_geopath", + type=str, + default="selu", + help="Activation function for GeoPath", + ) + parser.add_argument( + "--geopath_optimizer", + type=str, + default="adam", + help="Optimizer for GeoPath training", + ) + parser.add_argument( + "--geopath_lr", + type=float, + default=1e-4, + help="Learning rate for GeoPath training", + ) + parser.add_argument( + "--geopath_weight_decay", + type=float, + default=1e-5, + help="Weight decay for GeoPath training", + ) + parser.add_argument( + "--mmd_weight", + type=float, + default=0.1, + help="Weight for MMD loss at intermediate timepoints (only used when >2 timepoints)", + ) + return parser + + +def flow_network_parser(parser): + parser.add_argument( + "--sigma", type=float, default=0.1, help="Sigma parameter for CFM (variance)" + ) + parser.add_argument( + "--patience", + type=int, + default=5, + help="Patience for early stopping in CFM training", + ) + parser.add_argument( + "--hidden_dims_flow", + nargs="+", + type=int, + default=[64, 64, 64], + help="Dimensions of hidden layers for CFM training", + ) + parser.add_argument( + "--check_val_every_n_epoch", + type=int, + default=10, + help="Check validation every N epochs during CFM training", + ) + parser.add_argument( + "--activation_flow", + type=str, + default="selu", + help="Activation function for CFM", + ) + parser.add_argument( + "--flow_optimizer", + type=str, + default="adamw", + help="Optimizer for GeoPath training", + ) + parser.add_argument( + "--flow_lr", + type=float, + default=1e-3, + help="Learning rate for GeoPath training", + ) + parser.add_argument( + "--flow_weight_decay", + type=float, + default=1e-5, + help="Weight decay for GeoPath training", + ) + return parser + +def growth_network_parser(parser): + parser.add_argument( + "--patience_growth", + type=int, + default=5, + help="Patience for early stopping in CFM training", + ) + parser.add_argument( + "--time_growth", + action=argparse.BooleanOptionalAction, + default=False, + help="Use time in GeoPath model", + ) + parser.add_argument( + "--hidden_dims_growth", + nargs="+", + type=int, + default=[64, 64, 64], + help="Dimensions of hidden layers for growth net training", + ) + parser.add_argument( + "--activation_growth", + type=str, + default="tanh", + help="Activation function for CFM", + ) + parser.add_argument( + "--growth_optimizer", + type=str, + default="adamw", + help="Optimizer for GeoPath training", + ) + parser.add_argument( + "--growth_lr", + type=float, + default=1e-3, + help="Learning rate for GeoPath training", + ) + parser.add_argument( + "--growth_weight_decay", + type=float, + default=1e-5, + help="Weight decay for GeoPath training", + ) + parser.add_argument( + "--lambda_energy", + type=float, + default=1.0, + help="Weight for energy loss", + ) + parser.add_argument( + "--lambda_mass", + type=float, + default=100.0, + help="Weight for mass loss", + ) + parser.add_argument( + "--lambda_match", + type=float, + default=1000.0, + help="Weight for matching loss", + ) + parser.add_argument( + "--lambda_recons", + type=float, + default=1.0, + help="Weight for reconstruction loss", + ) + return parser \ No newline at end of file diff --git a/scripts/README.md b/scripts/README.md new file mode 100755 index 0000000000000000000000000000000000000000..596c51d49acbe2aa36ec95515ff4777b58a1ebef --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,226 @@ +# Running Experiments with BranchSBM 🌳🧬 + +This directory contains training scripts for all experiments with BranchSBM, including LiDAR navigation 🗻, simulating cell differentiation 🧫, and cell state perturbation modelling 🧬. This codebase contains code from the [Metric Flow Matching repo](https://github.com/kkapusniak/metric-flow-matching) ([Kapusniak et al. 2024](https://arxiv.org/abs/2405.14780)). + +## Environment Installation +``` +conda env create -f environment.yml + +conda activate branchsbm +``` + +## Data +LiDAR data is taken from the [Generalized Schrödinger Bridge Matching repo](https://github.com/facebookresearch/generalized-schrodinger-bridge-matching) and Mouse Hematopoesis is taken from the [DeepRUOT repo](https://github.com/zhenyiizhang/DeepRUOT) + +We use perturbation data from the [Tahoe-100M dataset](https://huggingface.co/datasets/tahoebio/Tahoe-100M) containing control DMSO-treated cell data and perturbed cell data. + +The raw data contains a total of 60K genes. We select the top 2000 highly variable genes (HVGs) and perform principal component analysis (PCA), to maximally capture the variance in the data via the top principal components (38% in the top-50 PCs). **Our goal is to learn the dynamic trajectories that map control cell clusters to the perturbd cell clusters.** + +**Specifically, we model the following perturbations**: + +1. **Clonidine**: Cell states under 5uM Clonidine perturbation at various PC dimensions (50D, 100D, 150D) with 1 unseen population. +2. **Trametinib**: Cell states under 5uM Trametinib perturbation (50D) with 2 unseen populations. + +All data files are stored in: +``` +BranchSBMl/data/ +├── rainier2-thin.las # LiDAR data +├── mouse_hematopoiesis.csv # Mouse Hematopoiesis data +├── pca_and_leiden_labels.csv # Clonidine data +├── Trametinib_5.0uM_pca_and_leidenumap_labels.csv # Trametinib data +└── Veres_alltime.csv # Pancreatic β-Cell data +``` + +## Running Experiments + +All training scripts are located in `BranchSBM/scripts/`. Each script is pre-configured for a specific experiment. + +The scripts for BranchSBM experiments include: + +- **`lidar.sh`** - LiDAR trajectory data with 2 branches +- **`mouse.sh`** - Mouse cell differentiation with 2 branches +- **`clonidine.sh`** - Clonidine perturbation with 2 branches +- **`trametinib.sh`** - Trametinib perturbation with 3 branches +- **`veres.sh`** - Pancreatic beta-cell differentiation with 11 branches + + +The scripts for the baseline single-branch SBM experiments include: + +- **`mouse_single.sh`** - Mouse single branch +- **`clonidine_single.sh`** - Clonidine single branch +- **`trametinib_single.sh`** - Trametinib single branch +- **`lidar_single.sh`** - LiDAR single branch + +**Before running experiments:** + +1. Set `HOME_LOC` to the base path where BranchSBM is located and `ENV_PATH` to the directory where your environment is downloaded in the `.sh` files in `scripts/` +2. Create a path `BranchSBM/results` where the simulated trajectory figures and metrics will be saved. Also, create `BranchSBM/logs` where the training logs will be saved. +3. Activate the conda environment: +``` +conda activate branchsbm +``` +4. Login to wandb using `wandb login` + +**Run experiment using `nohup` with the following commands:** + +``` +cd scripts + +chmod lidar.sh + +nohup ./lidar.sh > lidar.log 2>&1 & +``` + +Evaluation will run automatically after the specified number of rollouts `--num_rollouts` is finished. To see metrics, go to `results//metrics/` or the end of `logs/.log`. + +For Clonidine, `x1_1` indicates the cell cluster that is sampled from for training and `x1_2` is the held-out cell cluster. For Trametinib `x1_1` indicates the cell cluster that is sampled from for training and `x1_2` and `x1_3` are the held-out cell clusters. + +We report the following metrics for each of the clusters in our paper: +1. Maximum Mean Discrepancy (RBF-MMD) of simualted cell cluster with target cell cluster (same cell count). +2. 1-Wasserstein and 2-Wasserstein distances against full cell population in the cluster. + +## Overview of Outputs + +**Training outputs are saved to experiment-specific directories:** + +``` +BranchSBM/results/ +├── _clonidine50D_branched/ +│ └── figures/ # Figures of simulated +│ └── metrics.csv # JSON of metrics +``` + +**PyTorch Lightning automatically saves model checkpoints to:** + +``` +BranchSBM/scripts/lightning_logs/ +├── / +│ ├── checkpoints/ +│ │ ├── epoch=N-step=M.ckpt # Checkpoint +``` + +**Training logs are saved in:** +``` +entangled-cell/logs/ +├── _lidar_single_train.log +├── _lidar_train.log +├── _mouse_single_train.log +├── _mouse_train.log +├── _clonidine_single_train.log +├── _clonidine50D_train.log +├── _clonidine100D_train.log +├── _clonidine150D_train.log +├── _trametinib_single_train.log +├── _trametinib_train.log +└── _veres_train.log +``` + +## Available Experiments + +### Branched Experiments (Multi-branch trajectories) + +These experiments model cell differentiation or perturbation with multiple branches: + +- **`mouse.sh`** - Mouse cell differentiation with 2 branches (GPU 0) +- **`trametinib.sh`** - Trametinib perturbation with 3 branches (GPU 1) +- **`lidar.sh`** - LiDAR trajectory data with 2 branches (GPU 2) +- **`clonidine.sh`** - Clonidine perturbation with 2 branches (GPU 3) + +### Single-Branch Experiments (Control/baseline) + +These are baseline experiments with single trajectories: + +- **`mouse_single.sh`** - Mouse single trajectory (GPU 4) +- **`clonidine_single.sh`** - Clonidine single trajectory (GPU 5) +- **`trametinib_single.sh`** - Trametinib single trajectory (GPU 6) +- **`lidar_single.sh`** - LiDAR single trajectory (GPU 7) + +## Running Scripts + +### Run a single experiment + +From the `scripts/` directory: + +```bash +cd scripts +chmod +x mouse.sh +nohup ./mouse.sh > mouse.log 2>&1 & +``` + +### Run all branched experiments in parallel + +```bash +nohup ./mouse.sh > mouse.log 2>&1 & +nohup ./trametinib.sh > trametinib.log 2>&1 & +nohup ./lidar.sh > lidar.log 2>&1 & +nohup ./clonidine.sh > clonidine.log 2>&1 & +``` + +### Run all single-branch experiments in parallel + +```bash +nohup ./mouse_single.sh > mouse_single.log 2>&1 & +nohup ./clonidine_single.sh > clonidine_single.log 2>&1 & +nohup ./trametinib_single.sh > trametinib_single.log 2>&1 & +nohup ./lidar_single.sh > lidar_single.log 2>&1 & +``` + +### Run all experiments simultaneously + +Each script is assigned to a different GPU, so you can run all 8 in parallel: + +```bash +nohup ./mouse.sh > mouse.log 2>&1 & +nohup ./trametinib.sh > trametinib.log 2>&1 & +nohup ./lidar.sh > lidar.log 2>&1 & +nohup ./clonidine.sh > clonidine.log 2>&1 & +nohup ./mouse_single.sh > mouse_single.log 2>&1 & +nohup ./clonidine_single.sh > clonidine_single.log 2>&1 & +nohup ./trametinib_single.sh > trametinib_single.log 2>&1 & +nohup ./lidar_single.sh > lidar_single.log 2>&1 & +``` + +## Monitoring Training + +Logs are saved in `./BranchSBM/logs/` with format `MM_DD__train.log`. + +Each experiment logs to wandb with a unique run name: +- Branched experiments: `_branched` (e.g., `mouse_branched`) +- Single experiments: `_single` (e.g., `mouse_single`) + +Visit your wandb dashboard to view training progress in real-time. + +## Training Parameters + +Default training parameters for each experiment: + +| Parameter | LiDAR | Mouse Hematopoiesis scRNA | Clonidine (50 PCs) | Clonidine (100 PCs) | Clonidine (150 PCs) | Trametinib | Pancreatic β-Cell | +|---|---|---|---|---|---|---|---| +| branches | 2 | 2 | 2 | 2 | 2 | 3 | 11 | +| data dimension | 3 | 2 | 50 | 100 | 150 | 50 | 30 | +| batch size | 128 | 128 | 32 | 32 | 32 | 32 | 256 | +| λ_energy | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | +| λ_mass | 100 | 100 | 100 | 100 | 100 | 100 | 100 | +| λ_match | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | +| λ_recons | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | +| λ_growth | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | +| V_t | LAND | LAND | RBF | RBF | RBF | RBF | RBF | +| RBF N_c | - | - | 150 | 300 | 300 | 150 | 300 | +| RBF κ | - | - | 1.5 | 2.0 | 3.0 | 1.5 | 3.0 | +| hidden dimension | 64 | 64 | 1024 | 1024 | 1024 | 1024 | 1024 | +| lr interpolant | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | +| lr velocity | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | +| lr growth | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | + +To modify parameters, edit the corresponding `.sh` file. + +## Training Pipeline + +Each experiment runs through 4 stages: + +1. **Stage 1: Geopath** - Train geodesic path interpolants +2. **Stage 2: Flow Matching** - Train continuous normalizing flows +3. **Stage 3: Growth** - Train growth networks for branches +4. **Stage 4: Joint** - Joint training of all components + +Checkpoints are saved automatically and loaded between stages. \ No newline at end of file diff --git a/scripts/clonidine100.sh b/scripts/clonidine100.sh new file mode 100755 index 0000000000000000000000000000000000000000..2e4101d583d9216c32a1a694513d214ec5b39444 --- /dev/null +++ b/scripts/clonidine100.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='clonidine100D_branched' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=4 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --epochs 100 \ + --run_name "${DATE}_${SPECIAL_PREFIX}" \ + --config_path "$SCRIPT_LOC/configs/clonidine_100D.yaml" \ + --batch_size 32 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/clonidine150.sh b/scripts/clonidine150.sh new file mode 100755 index 0000000000000000000000000000000000000000..c5ba510169c24110d3955a143328ce192f5b0155 --- /dev/null +++ b/scripts/clonidine150.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='clonidine150D_branched' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=5 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --epochs 100 \ + --run_name "${DATE}_${SPECIAL_PREFIX}" \ + --config_path "$SCRIPT_LOC/configs/clonidine_150D.yaml" \ + --batch_size 32 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/clonidine50.sh b/scripts/clonidine50.sh new file mode 100755 index 0000000000000000000000000000000000000000..965efc36f5fd0430e6ff6c2fa16a3053ea2a6110 --- /dev/null +++ b/scripts/clonidine50.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='clonidine50D_branched' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=3 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --epochs 100 \ + --run_name ${DATE}_${SPECIAL_PREFIX} \ + --config_path "$SCRIPT_LOC/configs/clonidine_50D.yaml" \ + --batch_size 32 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/clonidine50_single.sh b/scripts/clonidine50_single.sh new file mode 100755 index 0000000000000000000000000000000000000000..43ae4ebef7867ee2e6abc4186423deb41e596882 --- /dev/null +++ b/scripts/clonidine50_single.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='clonidine_single' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=3 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --epochs 100 \ + --run_name "clonidine50D_single" \ + --config_path "$SCRIPT_LOC/configs/clonidine_50Dsingle.yaml" \ + --batch_size 32 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/lidar.sh b/scripts/lidar.sh new file mode 100755 index 0000000000000000000000000000000000000000..3750c5b1e789afcaeff9c2816ca5d82fab3b23cb --- /dev/null +++ b/scripts/lidar.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='lidar_branched' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=2 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --config_path "$SCRIPT_LOC/configs/lidar.yaml" \ + --epochs 10 \ + --run_name "${DATE}_${SPECIAL_PREFIX}" \ + --batch_size 128 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/lidar_single.sh b/scripts/lidar_single.sh new file mode 100755 index 0000000000000000000000000000000000000000..76a4660c6da5879af18ef9c671fbb16b4651953a --- /dev/null +++ b/scripts/lidar_single.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='lidar_single' +# set 3 have skip connection +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=2 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --config_path "$SCRIPT_LOC/configs/lidar_single.yaml" \ + --run_name "${DATE}_${SPECIAL_PREFIX}" \ + --epochs 100 \ + --batch_size 128 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/mouse.sh b/scripts/mouse.sh new file mode 100755 index 0000000000000000000000000000000000000000..d68f2661b37bcd01355f2773db7c0f5cca8de27b --- /dev/null +++ b/scripts/mouse.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='mouse_branched' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=1 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --config_path "$SCRIPT_LOC/configs/mouse.yaml" \ + --epochs 100 \ + --run_name "${DATE}_${SPECIAL_PREFIX}" >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/mouse_single.sh b/scripts/mouse_single.sh new file mode 100755 index 0000000000000000000000000000000000000000..eafce9d44cdb57ed82c715ac14f698cda53a3913 --- /dev/null +++ b/scripts/mouse_single.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='mouse_single' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=1 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --epochs 100 \ + --run_name "${DATE}_${SPECIAL_PREFIX}" \ + --config_path "$SCRIPT_LOC/configs/mouse_single.yaml" >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/trametinib.sh b/scripts/trametinib.sh new file mode 100755 index 0000000000000000000000000000000000000000..e7f82a0ba33544b85d2828ed9bd80412f2b3559d --- /dev/null +++ b/scripts/trametinib.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='trametinib_branched' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=6 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --epochs 100 \ + --run_name "${DATE}_${SPECIAL_PREFIX}" \ + --config_path "$SCRIPT_LOC/configs/trametinib.yaml" \ + --batch_size 32 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/trametinib_single.sh b/scripts/trametinib_single.sh new file mode 100755 index 0000000000000000000000000000000000000000..e2e74096c97e0a73223554cff28427d52e85901a --- /dev/null +++ b/scripts/trametinib_single.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='trametinib_single' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=6 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --epochs 100 \ + --run_name "${DATE}_${SPECIAL_PREFIX}" \ + --config_path "$SCRIPT_LOC/configs/trametinib_single.yaml" \ + --batch_size 32 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate \ No newline at end of file diff --git a/scripts/veres.sh b/scripts/veres.sh new file mode 100755 index 0000000000000000000000000000000000000000..36cacb158b498575368672d860df0770c41d45be --- /dev/null +++ b/scripts/veres.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +HOME_LOC=/path/to/your/home/BranchSBM +ENV_LOC=/path/to/your/envs/branchsbm +SCRIPT_LOC=$HOME_LOC +LOG_LOC=$HOME_LOC/logs +DATE=$(date +%m_%d) +SPECIAL_PREFIX='veres' +PYTHON_EXECUTABLE=$ENV_LOC/bin/python + +# Set GPU device +export CUDA_VISIBLE_DEVICES=7 + +# =================================================================== +source "$(conda info --base)/etc/profile.d/conda.sh" +conda activate $ENV_LOC + +cd $HOME_LOC + +$PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \ + --epochs 100 \ + --run_name ${DATE}_${SPECIAL_PREFIX} \ + --min_cells 100 \ + --config $SCRIPT_LOC/configs/veres.yaml >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1 + +conda deactivate diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100755 index 0000000000000000000000000000000000000000..8d54232deece350edf433991a4b17184ed4ce20e Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/branch_flow_net_test.py b/src/branch_flow_net_test.py new file mode 100755 index 0000000000000000000000000000000000000000..4ae5b689cb52221618676d27cacdccb5ad451745 --- /dev/null +++ b/src/branch_flow_net_test.py @@ -0,0 +1,1791 @@ +""" +Separate test classes for each BranchSBM experiment with specific plotting styles. +Each class handles testing and visualization for: LiDAR, Mouse, Clonidine, Trametinib, Veres. +""" + +import os +import json +import csv +import torch +import numpy as np +import matplotlib.pyplot as plt +import pytorch_lightning as pl +import random +import ot +from torchdyn.core import NeuralODE +from matplotlib.colors import LinearSegmentedColormap +from matplotlib.collections import LineCollection +from .networks.utils import flow_model_torch_wrapper +from .branch_flow_net_train import BranchFlowNetTrainBase +from .branch_growth_net_train import GrowthNetTrain +from .utils import wasserstein, mix_rbf_mmd2, plot_lidar +import json + +def evaluate_model(gt_data, model_data, a, b): + # ensure inputs are tensors + if not isinstance(gt_data, torch.Tensor): + gt_data = torch.tensor(gt_data, dtype=torch.float32) + if not isinstance(model_data, torch.Tensor): + model_data = torch.tensor(model_data, dtype=torch.float32) + + # choose device: prefer model_data's device if it's not CPU, otherwise use gt_data's device + try: + model_dev = model_data.device + except Exception: + model_dev = torch.device('cpu') + try: + gt_dev = gt_data.device + except Exception: + gt_dev = torch.device('cpu') + + device = model_dev if model_dev.type != 'cpu' else gt_dev + + gt = gt_data.to(device=device, dtype=torch.float32) + md = model_data.to(device=device, dtype=torch.float32) + + M = torch.cdist(gt, md, p=2).cpu().numpy() + if np.isnan(M).any() or np.isinf(M).any(): + return np.nan + return ot.emd2(a, b, M, numItermax=1e7) + +def compute_distribution_distances(pred, true, pred_full=None, true_full=None): + w1 = wasserstein(pred, true, power=1) + w2 = wasserstein(pred, true, power=2) + + # Use full dimensions for MMD if provided, otherwise use same as W1/W2 + mmd_pred = pred_full if pred_full is not None else pred + mmd_true = true_full if true_full is not None else true + + # MMD requires same number of samples — randomly subsample the larger set + n_pred, n_true = mmd_pred.shape[0], mmd_true.shape[0] + if n_pred > n_true: + perm = torch.randperm(n_pred)[:n_true] + mmd_pred = mmd_pred[perm] + elif n_true > n_pred: + perm = torch.randperm(n_true)[:n_pred] + mmd_true = mmd_true[perm] + mmd = mix_rbf_mmd2(mmd_pred, mmd_true, sigma_list=[0.01, 0.1, 1, 10, 100]).item() + + return {"W1": w1, "W2": w2, "MMD": mmd} + + +def compute_tmv_from_mass_over_time(mass_over_time, all_endpoints, time_points=None, timepoint_data=None, time_index=None, target_time=None, gt_key_template='t1_{}', weights_over_time=None): + + if weights_over_time is not None or mass_over_time is not None: + if time_index is None: + if target_time is not None and time_points is not None: + arr = np.array(time_points) + time_index = int(np.argmin(np.abs(arr - float(target_time)))) + else: + # default to last index + ref_list = weights_over_time if weights_over_time is not None else mass_over_time + time_index = len(ref_list[0]) - 1 + else: + # neither available; time_index not used + if time_index is None: + time_index = -1 + + n_branches = len(all_endpoints) + + # initial total cells for normalization + n_initial = None + if timepoint_data is not None and 't0' in timepoint_data: + try: + n_initial = int(timepoint_data['t0'].shape[0]) + except Exception: + n_initial = None + + pred_masses = [] + for i in range(n_branches): + # Use sum of actual particle weights if available, otherwise mean_weight * num_particles + if weights_over_time is not None: + try: + weights_tensor = weights_over_time[i][time_index] + # Sum all particle weights to get total mass for this branch + total_mass = float(weights_tensor.sum().item()) + pred_masses.append(total_mass) + continue + except Exception: + pass # Fall through to mean weight calculation + + # Fallback: mean weight from mass_over_time if available, otherwise assume weight=1 + mean_w = 1.0 + if mass_over_time is not None: + try: + mean_w = float(mass_over_time[i][time_index]) + except Exception: + mean_w = 1.0 + + # determine number of particles for this branch + num_particles = 0 + try: + if hasattr(all_endpoints[i], 'shape'): + num_particles = int(all_endpoints[i].shape[0]) + else: + num_particles = int(len(all_endpoints[i])) + except Exception: + num_particles = 0 + + pred_masses.append(mean_w * float(num_particles)) + + # ground-truth masses per branch + gt_masses = [] + if timepoint_data is not None: + for i in range(n_branches): + key1 = gt_key_template.format(i) + if key1 in timepoint_data: + gt_masses.append(float(timepoint_data[key1].shape[0])) + else: + base_key = gt_key_template.split("_")[0] if '_' in gt_key_template else gt_key_template + if base_key in timepoint_data: + gt_masses.append(float(timepoint_data[base_key].shape[0])) + else: + gt_masses.append(0.0) + else: + gt_masses = [0.0 for _ in range(n_branches)] + + # determine normalization denominator + if n_initial is None: + s = float(sum(gt_masses)) + if s > 0: + n_initial = s + else: + n_initial = float(sum(pred_masses)) if sum(pred_masses) > 0 else 1.0 + + pred_fracs = [m / float(n_initial) for m in pred_masses] + gt_fracs = [m / float(n_initial) for m in gt_masses] + + tmv = 0.5 * float(np.sum(np.abs(np.array(pred_fracs) - np.array(gt_fracs)))) + + return { + 'time_index': time_index, + 'pred_masses': pred_masses, + 'gt_masses': gt_masses, + 'pred_fracs': pred_fracs, + 'gt_fracs': gt_fracs, + 'tmv': tmv, + } + + +class FlowNetTestLidar(GrowthNetTrain): + + def test_step(self, batch, batch_idx): + # Unwrap CombinedLoader outer tuple if needed + if isinstance(batch, (list, tuple)) and len(batch) == 1: + batch = batch[0] + + if isinstance(batch, dict) and "test_samples" in batch: + test_samples = batch["test_samples"] + metric_samples = batch["metric_samples"] + + if isinstance(test_samples, (list, tuple)) and len(test_samples) >= 2 and isinstance(test_samples[-1], int): + test_samples = test_samples[0] + if isinstance(metric_samples, (list, tuple)) and len(metric_samples) >= 2 and isinstance(metric_samples[-1], int): + metric_samples = metric_samples[0] + + if isinstance(test_samples, (list, tuple)) and len(test_samples) == 1: + test_samples = test_samples[0] + main_batch = test_samples + + if isinstance(metric_samples, dict): + metric_batch = list(metric_samples.values()) + elif isinstance(metric_samples, (list, tuple)): + metric_batch = [m[0] if isinstance(m, (list, tuple)) and len(m) == 1 else m for m in metric_samples] + else: + metric_batch = [metric_samples] + elif isinstance(batch, (list, tuple)) and len(batch) == 2: + # Old tuple format: (test_samples, metric_samples) + # Each could be dict or list + test_samples = batch[0] + metric_samples = batch[1] + + if isinstance(test_samples, dict): + main_batch = test_samples + elif isinstance(test_samples, (list, tuple)): + main_batch = test_samples[0] + else: + main_batch = test_samples + + if isinstance(metric_samples, dict): + metric_batch = list(metric_samples.values()) + elif isinstance(metric_samples, (list, tuple)): + metric_batch = [m[0] if isinstance(m, (list, tuple)) and len(m) == 1 else m for m in metric_samples] + else: + metric_batch = [metric_samples] + else: + # Fallback + main_batch = batch + metric_batch = [] + + timepoint_data = self.trainer.datamodule.get_timepoint_data() + # main_batch is a dict like {"x0": (tensor, weights), ...} + if isinstance(main_batch, dict): + device = main_batch["x0"][0].device + else: + device = main_batch[0]["x0"][0].device + + x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device) + w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device) + full_batch = {"x0": (x0_all, w0_all)} + + time_points, all_endpoints, all_trajs, mass_over_time, energy_over_time, weights_over_time = self.get_mass_and_position(full_batch, metric_batch) + + cloud_points = main_batch["dataset"][0] # [N, 3] + + # Run 5 trials with random subsampling for robust metrics + n_trials = 5 + + # Compute per-branch metrics + metrics_dict = {} + for i, endpoints in enumerate(all_endpoints): + true_data_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1' + true_data = torch.tensor(timepoint_data[true_data_key], dtype=torch.float32).to(endpoints.device) + + w1_br, w2_br, mmd_br = [], [], [] + for trial in range(n_trials): + n_min = min(endpoints.shape[0], true_data.shape[0]) + perm_pred = torch.randperm(endpoints.shape[0])[:n_min] + perm_gt = torch.randperm(true_data.shape[0])[:n_min] + m = compute_distribution_distances( + endpoints[perm_pred, :2], true_data[perm_gt, :2], + pred_full=endpoints[perm_pred], true_full=true_data[perm_gt] + ) + w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"]) + + metrics_dict[f"branch_{i+1}"] = { + "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)), + "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)), + "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)), + } + self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True) + print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, " + f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, " + f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}") + + # Compute combined metrics across all branches (5 trials) + all_pred_combined = torch.cat(list(all_endpoints), dim=0) + all_true_list = [] + for i in range(len(all_endpoints)): + true_data_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1' + all_true_list.append(torch.tensor(timepoint_data[true_data_key], dtype=torch.float32).to(all_pred_combined.device)) + all_true_combined = torch.cat(all_true_list, dim=0) + + w1_trials, w2_trials, mmd_trials = [], [], [] + for trial in range(n_trials): + n_min = min(all_pred_combined.shape[0], all_true_combined.shape[0]) + perm_pred = torch.randperm(all_pred_combined.shape[0])[:n_min] + perm_gt = torch.randperm(all_true_combined.shape[0])[:n_min] + m = compute_distribution_distances( + all_pred_combined[perm_pred, :2], all_true_combined[perm_gt, :2], + pred_full=all_pred_combined[perm_pred], true_full=all_true_combined[perm_gt] + ) + w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"]) + + w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1) + w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1) + mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1) + self.log("test/W1_combined", w1_mean, on_epoch=True) + self.log("test/W2_combined", w2_mean, on_epoch=True) + self.log("test/MMD_combined", mmd_mean, on_epoch=True) + + metrics_dict["combined"] = { + "W1_mean": float(w1_mean), "W1_std": float(w1_std), + "W2_mean": float(w2_mean), "W2_std": float(w2_std), + "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std), + "n_trials": n_trials, + } + print(f"\n=== Combined ===") + print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}") + print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}") + print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}") + + # Inverse-transform cloud points for visualization + if self.whiten: + cloud_points = torch.tensor( + self.trainer.datamodule.scaler.inverse_transform( + cloud_points.cpu().detach().numpy() + ) + ) + + # Create results directory structure + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + figures_dir = f'{results_dir}/figures' + os.makedirs(figures_dir, exist_ok=True) + + # Save metrics to JSON + metrics_path = f'{results_dir}/metrics.json' + with open(metrics_path, 'w') as f: + json.dump(metrics_dict, f, indent=2) + print(f"Metrics saved to {metrics_path}") + + # Save detailed per-branch metrics to CSV + detailed_csv_path = f'{results_dir}/metrics_detailed.csv' + with open(detailed_csv_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std']) + for key in sorted(metrics_dict.keys()): + m = metrics_dict[key] + writer.writerow([key, + f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}', + f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}', + f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}']) + print(f"Detailed metrics CSV saved to {detailed_csv_path}") + + # Convert all_trajs from list of lists to stacked tensors for plotting + # all_trajs[i] is a list of T tensors of shape [B, D] + # Stack to get shape [B, T, D] + stacked_trajs = [] + for traj_list in all_trajs: + # Stack along time dimension (dim=1) to get [B, T, D] + stacked_traj = torch.stack(traj_list, dim=1) + stacked_trajs.append(stacked_traj) + + # Inverse-transform trajectories to match cloud_points coordinates + if self.whiten: + stacked_trajs_original = [] + for traj in stacked_trajs: + B, T, D = traj.shape + # Reshape to [B*T, D] for inverse transform + traj_flat = traj.reshape(-1, D).cpu().detach().numpy() + traj_inv = self.trainer.datamodule.scaler.inverse_transform(traj_flat) + # Reshape back to [B, T, D] + traj_inv = torch.tensor(traj_inv).reshape(B, T, D) + stacked_trajs_original.append(traj_inv) + stacked_trajs = stacked_trajs_original + + # ===== Plot all branches together ===== + fig = plt.figure(figsize=(10, 8)) + ax = fig.add_subplot(111, projection="3d", computed_zorder=False) + ax.view_init(elev=30, azim=-115, roll=0) + for i, traj in enumerate(stacked_trajs): + plot_lidar(ax, cloud_points, xs=traj, branch_idx=i) + plt.savefig(f'{figures_dir}/{self.args.data_name}_all_branches.png', dpi=300) + plt.close() + + # ===== Plot each branch separately ===== + for i, traj in enumerate(stacked_trajs): + fig = plt.figure(figsize=(10, 8)) + ax = fig.add_subplot(111, projection="3d", computed_zorder=False) + ax.view_init(elev=30, azim=-115, roll=0) + plot_lidar(ax, cloud_points, xs=traj, branch_idx=i) + plt.savefig(f'{figures_dir}/{self.args.data_name}_branch_{i + 1}.png', dpi=300) + plt.close() + + print(f"LiDAR figures saved to {figures_dir}") + + +class FlowNetTestMouse(GrowthNetTrain): + + def test_step(self, batch, batch_idx): + # Handle both tuple and dict batch formats from CombinedLoader + if isinstance(batch, dict): + main_batch = batch.get("test_samples", batch) + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + elif isinstance(batch, (list, tuple)) and len(batch) >= 1: + if isinstance(batch[0], dict): + main_batch = batch[0].get("test_samples", batch[0]) + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + else: + main_batch = batch[0][0] + else: + main_batch = batch + + device = main_batch["x0"][0].device + + # Use val x0 as initial conditions + x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device) + + # Get timepoint data for ground truth + timepoint_data = self.trainer.datamodule.get_timepoint_data() + + # Ground truth at t1 (intermediate timepoint) + data_t1 = torch.tensor(timepoint_data['t1'], dtype=torch.float32) + + # Define color schemes for mouse (2 branches) + custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"] + custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"] + custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1) + custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2) + + t_span_full = torch.linspace(0, 1.0, 100).to(device) + all_trajs = [] + + for i, flow_net in enumerate(self.flow_nets): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ).to(device) + + with torch.no_grad(): + traj = node.trajectory(x0, t_span_full).cpu() # [T, B, D] + + traj = torch.transpose(traj, 0, 1) # [B, T, D] + all_trajs.append(traj) + + t_span_metric_t1 = torch.linspace(0, 0.5, 50).to(device) + t_span_metric_t2 = torch.linspace(0, 1.0, 100).to(device) + n_trials = 5 + + # Gather t2 branch ground truth + data_t2_branches = [] + for i in range(len(self.flow_nets)): + key = f't2_{i+1}' + if key in timepoint_data: + data_t2_branches.append(torch.tensor(timepoint_data[key], dtype=torch.float32)) + elif i == 0 and 't2' in timepoint_data: + data_t2_branches.append(torch.tensor(timepoint_data['t2'], dtype=torch.float32)) + else: + data_t2_branches.append(None) + + # Combined t2 ground truth (all branches merged) + data_t2_all_list = [d for d in data_t2_branches if d is not None] + data_t2_combined = torch.cat(data_t2_all_list, dim=0) if data_t2_all_list else None + + # ---- t1 combined metrics (all branches pooled, compared to t1) ---- + w1_t1_trials, w2_t1_trials, mmd_t1_trials = [], [], [] + + for trial in range(n_trials): + all_preds = [] + for i, flow_net in enumerate(self.flow_nets): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ).to(device) + + with torch.no_grad(): + traj = node.trajectory(x0, t_span_metric_t1) # [T, B, D] + + x_final = traj[-1].cpu() # [B, D] + all_preds.append(x_final) + + preds = torch.cat(all_preds, dim=0) + target_size = preds.shape[0] + perm = torch.randperm(data_t1.shape[0])[:target_size] + data_t1_reduced = data_t1[perm] + + metrics = compute_distribution_distances( + preds[:, :2], data_t1_reduced[:, :2] + ) + w1_t1_trials.append(metrics["W1"]) + w2_t1_trials.append(metrics["W2"]) + mmd_t1_trials.append(metrics["MMD"]) + + # ---- t2 per-branch metrics (each branch endpoint vs its own t2 cluster) ---- + branch_t2_metrics = {} + for i, flow_net in enumerate(self.flow_nets): + if data_t2_branches[i] is None: + continue + w1_br, w2_br, mmd_br = [], [], [] + for trial in range(n_trials): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ).to(device) + with torch.no_grad(): + traj = node.trajectory(x0, t_span_metric_t2) + x_final = traj[-1].cpu() + gt = data_t2_branches[i] + n_min = min(x_final.shape[0], gt.shape[0]) + perm_pred = torch.randperm(x_final.shape[0])[:n_min] + perm_gt = torch.randperm(gt.shape[0])[:n_min] + m = compute_distribution_distances( + x_final[perm_pred, :2], gt[perm_gt, :2] + ) + w1_br.append(m["W1"]) + w2_br.append(m["W2"]) + mmd_br.append(m["MMD"]) + branch_t2_metrics[f"branch_{i+1}_t2"] = { + "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)), + "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)), + "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)), + } + print(f"Branch {i+1} @ t2 — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, " + f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, " + f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}") + + # ---- t2 combined metrics (all branches pooled, compared to all t2) ---- + w1_t2_trials, w2_t2_trials, mmd_t2_trials = [], [], [] + if data_t2_combined is not None: + for trial in range(n_trials): + all_preds = [] + for i, flow_net in enumerate(self.flow_nets): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ).to(device) + with torch.no_grad(): + traj = node.trajectory(x0, t_span_metric_t2) + all_preds.append(traj[-1].cpu()) + preds = torch.cat(all_preds, dim=0) + n_min = min(preds.shape[0], data_t2_combined.shape[0]) + perm_pred = torch.randperm(preds.shape[0])[:n_min] + perm_gt = torch.randperm(data_t2_combined.shape[0])[:n_min] + m = compute_distribution_distances( + preds[perm_pred, :2], data_t2_combined[perm_gt, :2] + ) + w1_t2_trials.append(m["W1"]) + w2_t2_trials.append(m["W2"]) + mmd_t2_trials.append(m["MMD"]) + + # Compute mean and std + w1_t1_mean, w1_t1_std = np.mean(w1_t1_trials), np.std(w1_t1_trials, ddof=1) + w2_t1_mean, w2_t1_std = np.mean(w2_t1_trials), np.std(w2_t1_trials, ddof=1) + mmd_t1_mean, mmd_t1_std = np.mean(mmd_t1_trials), np.std(mmd_t1_trials, ddof=1) + + # Log metrics + self.log("test/W1_combined_t1", w1_t1_mean, on_epoch=True) + self.log("test/W2_combined_t1", w2_t1_mean, on_epoch=True) + self.log("test/MMD_combined_t1", mmd_t1_mean, on_epoch=True) + + metrics_dict = { + "combined_t1": { + "W1_mean": float(w1_t1_mean), "W1_std": float(w1_t1_std), + "W2_mean": float(w2_t1_mean), "W2_std": float(w2_t1_std), + "MMD_mean": float(mmd_t1_mean), "MMD_std": float(mmd_t1_std), + "n_trials": n_trials, + } + } + metrics_dict.update(branch_t2_metrics) + + if w1_t2_trials: + w1_t2_mean, w1_t2_std = np.mean(w1_t2_trials), np.std(w1_t2_trials, ddof=1) + w2_t2_mean, w2_t2_std = np.mean(w2_t2_trials), np.std(w2_t2_trials, ddof=1) + mmd_t2_mean, mmd_t2_std = np.mean(mmd_t2_trials), np.std(mmd_t2_trials, ddof=1) + self.log("test/W1_combined_t2", w1_t2_mean, on_epoch=True) + self.log("test/W2_combined_t2", w2_t2_mean, on_epoch=True) + self.log("test/MMD_combined_t2", mmd_t2_mean, on_epoch=True) + metrics_dict["combined_t2"] = { + "W1_mean": float(w1_t2_mean), "W1_std": float(w1_t2_std), + "W2_mean": float(w2_t2_mean), "W2_std": float(w2_t2_std), + "MMD_mean": float(mmd_t2_mean), "MMD_std": float(mmd_t2_std), + "n_trials": n_trials, + } + + print(f"\n=== Combined @ t1 ===") + print(f"W1: {w1_t1_mean:.6f} ± {w1_t1_std:.6f}") + print(f"W2: {w2_t1_mean:.6f} ± {w2_t1_std:.6f}") + print(f"MMD: {mmd_t1_mean:.6f} ± {mmd_t1_std:.6f}") + if w1_t2_trials: + print(f"\n=== Combined @ t2 ===") + print(f"W1: {w1_t2_mean:.6f} ± {w1_t2_std:.6f}") + print(f"W2: {w2_t2_mean:.6f} ± {w2_t2_std:.6f}") + print(f"MMD: {mmd_t2_mean:.6f} ± {mmd_t2_std:.6f}") + + # Create results directory structure + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + figures_dir = f'{results_dir}/figures' + os.makedirs(figures_dir, exist_ok=True) + + # Save metrics to JSON + metrics_path = f'{results_dir}/metrics.json' + with open(metrics_path, 'w') as f: + json.dump(metrics_dict, f, indent=2) + print(f"Metrics saved to {metrics_path}") + + + # Save detailed metrics to CSV + detailed_csv_path = f'{results_dir}/metrics_detailed.csv' + with open(detailed_csv_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std']) + for key in sorted(metrics_dict.keys()): + m = metrics_dict[key] + writer.writerow([key, + f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}', + f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}', + f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}']) + print(f"Detailed metrics CSV saved to {detailed_csv_path}") + + # ===== Plot individual branches (using full t_span trajectories) ===== + self._plot_mouse_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2) + + # ===== Plot all branches together ===== + self._plot_mouse_combined(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2) + + print(f"Mouse figures saved to {figures_dir}") + + def _plot_mouse_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2): + """Plot each branch separately with timepoint background.""" + n_branches = len(all_trajs) + branch_names = [f'Branch {i+1}' for i in range(n_branches)] + branch_colors = ['#B83CFF', '#50B2D7'][:n_branches] + cmaps = [cmap1, cmap2][:n_branches] + + # Stack list-of-tensors into [B, T, D] numpy arrays + all_trajs_np = [] + for traj in all_trajs: + if isinstance(traj, list): + traj = torch.stack(traj, dim=1) # list of [B,D] -> [B,T,D] + all_trajs_np.append(traj.cpu().detach().numpy()) + all_trajs = all_trajs_np + + # Move timepoint data to numpy + for key in list(timepoint_data.keys()): + if torch.is_tensor(timepoint_data[key]): + timepoint_data[key] = timepoint_data[key].cpu().numpy() + + # Compute global axis limits + all_coords = [] + for key in ['t0', 't1', 't2', 't2_1', 't2_2']: + if key in timepoint_data: + all_coords.append(timepoint_data[key][:, :2]) + for traj_np in all_trajs: + all_coords.append(traj_np.reshape(-1, traj_np.shape[-1])[:, :2]) + + all_coords = np.concatenate(all_coords, axis=0) + x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max() + y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max() + + # Add margin + x_margin = 0.05 * (x_max - x_min) + y_margin = 0.05 * (y_max - y_min) + x_min -= x_margin + x_max += x_margin + y_min -= y_margin + y_max += y_margin + + for i, traj in enumerate(all_trajs): + fig, ax = plt.subplots(figsize=(10, 8)) + cmap = cmaps[i] + c_end = branch_colors[i] + + # Plot timepoint background + t2_key = f't2_{i+1}' if f't2_{i+1}' in timepoint_data else 't2' + coords_list = [timepoint_data['t0'], timepoint_data['t1'], timepoint_data[t2_key]] + tp_colors = ['#05009E', '#A19EFF', c_end] + tp_labels = ["t=0", "t=1", f"t=2 (branch {i+1})"] + + for coords, color, label in zip(coords_list, tp_colors, tp_labels): + alpha = 0.8 if color == '#05009E' else 0.6 + ax.scatter(coords[:, 0], coords[:, 1], + c=color, s=80, alpha=alpha, marker='x', + label=f'{label} cells', linewidth=1.5) + + # Plot continuous trajectories with LineCollection for speed + traj_2d = traj[:, :, :2] + n_time = traj_2d.shape[1] + color_vals = cmap(np.linspace(0, 1, n_time)) + segments = [] + seg_colors = [] + for j in range(traj_2d.shape[0]): + pts = traj_2d[j] # [T, 2] + segs = np.stack([pts[:-1], pts[1:]], axis=1) + segments.append(segs) + seg_colors.append(color_vals[:-1]) + segments = np.concatenate(segments, axis=0) + seg_colors = np.concatenate(seg_colors, axis=0) + lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8) + ax.add_collection(lc) + + # Start and end points + ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], + c='#05009E', s=30, marker='o', label='Trajectory Start', + zorder=5, edgecolors='white', linewidth=1) + ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], + c=c_end, s=30, marker='o', label='Trajectory End', + zorder=5, edgecolors='white', linewidth=1) + + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) + ax.set_xlabel("PC1", fontsize=12) + ax.set_ylabel("PC2", fontsize=12) + ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14) + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=12, frameon=False) + + plt.tight_layout() + plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300) + plt.close() + + def _plot_mouse_combined(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2): + """Plot all branches together.""" + n_branches = len(all_trajs) + branch_names = [f'Branch {i+1}' for i in range(n_branches)] + branch_colors = ['#B83CFF', '#50B2D7'][:n_branches] + + # Build timepoint key/color/label lists depending on branching + if 't2_1' in timepoint_data: + tp_keys = ['t0', 't1', 't2_1', 't2_2'] + tp_colors = ['#05009E', '#A19EFF', '#B83CFF', '#50B2D7'] + tp_labels = ['t=0', 't=1', 't=2 (branch 1)', 't=2 (branch 2)'] + else: + tp_keys = ['t0', 't1', 't2'] + tp_colors = ['#05009E', '#A19EFF', '#B83CFF'] + tp_labels = ['t=0', 't=1', 't=2'] + + # Stack list-of-tensors into [B, T, D] numpy arrays + all_trajs_np = [] + for traj in all_trajs: + if isinstance(traj, list): + traj = torch.stack(traj, dim=1) + if torch.is_tensor(traj): + traj = traj.cpu().detach().numpy() + all_trajs_np.append(traj) + all_trajs = all_trajs_np + + # Move timepoint data to numpy + for key in list(timepoint_data.keys()): + if torch.is_tensor(timepoint_data[key]): + timepoint_data[key] = timepoint_data[key].cpu().numpy() + + fig, ax = plt.subplots(figsize=(12, 10)) + + # Plot timepoint background + for idx, (t_key, color, label) in enumerate(zip( + tp_keys, + tp_colors, + tp_labels + )): + if t_key in timepoint_data: + coords = timepoint_data[t_key] + ax.scatter(coords[:, 0], coords[:, 1], + c=color, s=80, alpha=0.4, marker='x', + label=f'{label} cells', linewidth=1.5) + + # Plot trajectories with color gradients + cmaps = [cmap1, cmap2] + for i, traj in enumerate(all_trajs): + traj_2d = traj[:, :, :2] + c_end = branch_colors[i] + cmap = cmaps[i] + n_time = traj_2d.shape[1] + color_vals = cmap(np.linspace(0, 1, n_time)) + segments = [] + seg_colors = [] + for j in range(traj_2d.shape[0]): + pts = traj_2d[j] + segs = np.stack([pts[:-1], pts[1:]], axis=1) + segments.append(segs) + seg_colors.append(color_vals[:-1]) + segments = np.concatenate(segments, axis=0) + seg_colors = np.concatenate(seg_colors, axis=0) + lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8) + ax.add_collection(lc) + + ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], + c='#05009E', s=30, marker='o', + label=f'{branch_names[i]} Start', + zorder=5, edgecolors='white', linewidth=1) + ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], + c=c_end, s=30, marker='o', + label=f'{branch_names[i]} End', + zorder=5, edgecolors='white', linewidth=1) + + ax.set_xlabel("PC1", fontsize=14) + ax.set_ylabel("PC2", fontsize=14) + ax.set_title("All Branch Trajectories with Timepoint Background", + fontsize=16, weight='bold') + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=12, frameon=False) + + plt.tight_layout() + plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300) + plt.close() + + +class FlowNetTestClonidine(BranchFlowNetTrainBase): + """Test class for Clonidine perturbation experiment (1 or 2 branches).""" + + def test_step(self, batch, batch_idx): + # Handle both dict and tuple batch formats from CombinedLoader + if isinstance(batch, dict) and "test_samples" in batch: + # New format: {"test_samples": {...}, "metric_samples": {...}} + main_batch = batch["test_samples"] + elif isinstance(batch, (list, tuple)) and len(batch) >= 1: + # Old format with nested structure + test_samples = batch[0] + if isinstance(test_samples, dict) and "test_samples" in test_samples: + main_batch = test_samples["test_samples"][0] + else: + main_batch = test_samples + else: + # Fallback + main_batch = batch + + # Get timepoint data + timepoint_data = self.trainer.datamodule.get_timepoint_data() + device = main_batch["x0"][0].device + + # Use val x0 as initial conditions + x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device) + t_span = torch.linspace(0, 1, 100).to(device) + + # Define color schemes for clonidine (2 branches) + custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"] + custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"] + custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1) + custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2) + + all_trajs = [] + all_endpoints = [] + + for i, flow_net in enumerate(self.flow_nets): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ) + + with torch.no_grad(): + traj = node.trajectory(x0, t_span).cpu() # [T, B, D] + + traj = torch.transpose(traj, 0, 1) # [B, T, D] + all_trajs.append(traj) + all_endpoints.append(traj[:, -1, :]) + + # Run 5 trials with random subsampling for robust metrics + n_trials = 5 + n_branches = len(self.flow_nets) + + # Gather per-branch ground truth + gt_data_per_branch = [] + for i in range(n_branches): + if n_branches == 1: + key = 't1' + else: + key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1' + gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32)) + gt_all = torch.cat(gt_data_per_branch, dim=0) + + # Per-branch metrics (5 trials) + metrics_dict = {} + for i in range(n_branches): + w1_br, w2_br, mmd_br = [], [], [] + pred = all_endpoints[i] + gt = gt_data_per_branch[i] + for trial in range(n_trials): + n_min = min(pred.shape[0], gt.shape[0]) + perm_pred = torch.randperm(pred.shape[0])[:n_min] + perm_gt = torch.randperm(gt.shape[0])[:n_min] + m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2]) + w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"]) + metrics_dict[f"branch_{i+1}"] = { + "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)), + "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)), + "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)), + } + self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True) + print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, " + f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, " + f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}") + + # Combined metrics (5 trials) + pred_all = torch.cat(all_endpoints, dim=0) + w1_trials, w2_trials, mmd_trials = [], [], [] + for trial in range(n_trials): + n_min = min(pred_all.shape[0], gt_all.shape[0]) + perm_pred = torch.randperm(pred_all.shape[0])[:n_min] + perm_gt = torch.randperm(gt_all.shape[0])[:n_min] + m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2]) + w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"]) + + w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1) + w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1) + mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1) + self.log("test/W1_t1_combined", w1_mean, on_epoch=True) + self.log("test/W2_t1_combined", w2_mean, on_epoch=True) + self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True) + metrics_dict['t1_combined'] = { + "W1_mean": float(w1_mean), "W1_std": float(w1_std), + "W2_mean": float(w2_mean), "W2_std": float(w2_std), + "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std), + "n_trials": n_trials, + } + print(f"\n=== Combined @ t1 ===") + print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}") + print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}") + print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}") + + # Create results directory structure + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + figures_dir = f'{results_dir}/figures' + os.makedirs(figures_dir, exist_ok=True) + + # Save metrics to JSON + metrics_path = f'{results_dir}/metrics.json' + with open(metrics_path, 'w') as f: + json.dump(metrics_dict, f, indent=2) + print(f"Metrics saved to {metrics_path}") + + # Save detailed metrics to CSV + detailed_csv_path = f'{results_dir}/metrics_detailed.csv' + with open(detailed_csv_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std']) + for key in sorted(metrics_dict.keys()): + m = metrics_dict[key] + writer.writerow([key, + f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}', + f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}', + f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}']) + print(f"Detailed metrics CSV saved to {detailed_csv_path}") + + # ===== Plot branches ===== + self._plot_clonidine_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2) + self._plot_clonidine_combined(all_trajs, timepoint_data, figures_dir) + + print(f"Clonidine figures saved to {figures_dir}") + + def _plot_clonidine_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2): + """Plot each branch separately.""" + branch_names = ['Branch 1', 'Branch 2'] + branch_colors = ['#B83CFF', '#50B2D7'] + cmaps = [cmap1, cmap2] + + # Compute global axis limits – handle single vs multi branch keys + all_coords = [] + if 't1_1' in timepoint_data: + tp_keys = ['t0'] + [f't1_{i+1}' for i in range(len(all_trajs))] + else: + tp_keys = ['t0', 't1'] + for key in tp_keys: + all_coords.append(timepoint_data[key][:, :2]) + for traj in all_trajs: + all_coords.append(traj.reshape(-1, traj.shape[-1])[:, :2]) + + all_coords = np.concatenate(all_coords, axis=0) + x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max() + y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max() + + x_margin = 0.05 * (x_max - x_min) + y_margin = 0.05 * (y_max - y_min) + x_min -= x_margin + x_max += x_margin + y_min -= y_margin + y_max += y_margin + + for i, traj in enumerate(all_trajs): + fig, ax = plt.subplots(figsize=(10, 8)) + c_end = branch_colors[i] + + # Plot timepoint background + t1_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1' + coords_list = [timepoint_data['t0'], timepoint_data[t1_key]] + tp_colors = ['#05009E', c_end] + t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1" + tp_labels = ["t=0", t1_label] + + for coords, color, label in zip(coords_list, tp_colors, tp_labels): + ax.scatter(coords[:, 0], coords[:, 1], + c=color, s=80, alpha=0.4, marker='x', + label=f'{label} cells', linewidth=1.5) + + # Plot continuous trajectories with LineCollection for speed + traj_2d = traj[:, :, :2] + n_time = traj_2d.shape[1] + color_vals = cmaps[i](np.linspace(0, 1, n_time)) + segments = [] + seg_colors = [] + for j in range(traj_2d.shape[0]): + pts = traj_2d[j] + segs = np.stack([pts[:-1], pts[1:]], axis=1) + segments.append(segs) + seg_colors.append(color_vals[:-1]) + segments = np.concatenate(segments, axis=0) + seg_colors = np.concatenate(seg_colors, axis=0) + lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8) + ax.add_collection(lc) + + # Start and end points + ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], + c='#05009E', s=30, marker='o', label='Trajectory Start', + zorder=5, edgecolors='white', linewidth=1) + ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], + c=c_end, s=30, marker='o', label='Trajectory End', + zorder=5, edgecolors='white', linewidth=1) + + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) + ax.set_xlabel("PC1", fontsize=12) + ax.set_ylabel("PC2", fontsize=12) + ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14) + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=16, frameon=False) + + plt.tight_layout() + plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300) + plt.close() + + def _plot_clonidine_combined(self, all_trajs, timepoint_data, save_dir): + """Plot all branches together.""" + branch_names = ['Branch 1', 'Branch 2'] + branch_colors = ['#B83CFF', '#50B2D7'] + + fig, ax = plt.subplots(figsize=(12, 10)) + + # Build timepoint keys/colors/labels depending on single vs multi branch + if 't1_1' in timepoint_data: + tp_keys = ['t0'] + [f't1_{j+1}' for j in range(len(all_trajs))] + tp_labels_list = ['t=0'] + [f't=1 (branch {j+1})' for j in range(len(all_trajs))] + else: + tp_keys = ['t0', 't1'] + tp_labels_list = ['t=0', 't=1'] + tp_colors = ['#05009E', '#B83CFF', '#50B2D7'][:len(tp_keys)] + + # Plot timepoint background + for t_key, color, label in zip(tp_keys, tp_colors, tp_labels_list): + coords = timepoint_data[t_key] + ax.scatter(coords[:, 0], coords[:, 1], + c=color, s=80, alpha=0.4, marker='x', + label=f'{label} cells', linewidth=1.5) + + # Plot trajectories with color gradients + custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"] + custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"] + cmaps = [ + LinearSegmentedColormap.from_list("clon_cmap1", custom_colors_1), + LinearSegmentedColormap.from_list("clon_cmap2", custom_colors_2), + ] + for i, traj in enumerate(all_trajs): + traj_2d = traj[:, :, :2] + c_end = branch_colors[i] + cmap = cmaps[i] + n_time = traj_2d.shape[1] + color_vals = cmap(np.linspace(0, 1, n_time)) + segments = [] + seg_colors = [] + for j in range(traj_2d.shape[0]): + pts = traj_2d[j] + segs = np.stack([pts[:-1], pts[1:]], axis=1) + segments.append(segs) + seg_colors.append(color_vals[:-1]) + segments = np.concatenate(segments, axis=0) + seg_colors = np.concatenate(seg_colors, axis=0) + lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8) + ax.add_collection(lc) + + ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], + c='#05009E', s=30, marker='o', + label=f'{branch_names[i]} Start', + zorder=5, edgecolors='white', linewidth=1) + ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], + c=c_end, s=30, marker='o', + label=f'{branch_names[i]} End', + zorder=5, edgecolors='white', linewidth=1) + + ax.set_xlabel("PC1", fontsize=14) + ax.set_ylabel("PC2", fontsize=14) + ax.set_title("All Branch Trajectories with Timepoint Background", + fontsize=16, weight='bold') + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=12, frameon=False) + + plt.tight_layout() + plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300) + plt.close() + + +class FlowNetTestTrametinib(BranchFlowNetTrainBase): + """Test class for Trametinib perturbation experiment (1 or 3 branches).""" + + def test_step(self, batch, batch_idx): + # Handle both dict and tuple batch formats from CombinedLoader + if isinstance(batch, dict) and "test_samples" in batch: + # New format: {"test_samples": {...}, "metric_samples": {...}} + main_batch = batch["test_samples"] + elif isinstance(batch, (list, tuple)) and len(batch) >= 1: + # Old format with nested structure + test_samples = batch[0] + if isinstance(test_samples, dict) and "test_samples" in test_samples: + main_batch = test_samples["test_samples"][0] + else: + main_batch = test_samples + else: + # Fallback + main_batch = batch + + # Get timepoint data + timepoint_data = self.trainer.datamodule.get_timepoint_data() + device = main_batch["x0"][0].device + + # Use val x0 as initial conditions + x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device) + t_span = torch.linspace(0, 1, 100).to(device) + + # Define color schemes for trametinib (3 branches) + custom_colors_1 = ["#05009E", "#A19EFF", "#9793F8"] + custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"] + custom_colors_3 = ["#05009E", "#A19EFF", "#B83CFF"] + custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1) + custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2) + custom_cmap_3 = LinearSegmentedColormap.from_list("cmap3", custom_colors_3) + + all_trajs = [] + all_endpoints = [] + + for i, flow_net in enumerate(self.flow_nets): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ) + + with torch.no_grad(): + traj = node.trajectory(x0, t_span).cpu() # [T, B, D] + + traj = torch.transpose(traj, 0, 1) # [B, T, D] + all_trajs.append(traj) + all_endpoints.append(traj[:, -1, :]) + + # Run 5 trials with random subsampling for robust metrics + n_trials = 5 + n_branches = len(self.flow_nets) + + # Gather per-branch ground truth + gt_data_per_branch = [] + for i in range(n_branches): + if n_branches == 1: + key = 't1' + else: + key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1' + gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32)) + gt_all = torch.cat(gt_data_per_branch, dim=0) + + # Per-branch metrics (5 trials) + metrics_dict = {} + for i in range(n_branches): + w1_br, w2_br, mmd_br = [], [], [] + pred = all_endpoints[i] + gt = gt_data_per_branch[i] + for trial in range(n_trials): + n_min = min(pred.shape[0], gt.shape[0]) + perm_pred = torch.randperm(pred.shape[0])[:n_min] + perm_gt = torch.randperm(gt.shape[0])[:n_min] + m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2]) + w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"]) + metrics_dict[f"branch_{i+1}"] = { + "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)), + "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)), + "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)), + } + self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True) + print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, " + f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, " + f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}") + + # Combined metrics (5 trials) + pred_all = torch.cat(all_endpoints, dim=0) + w1_trials, w2_trials, mmd_trials = [], [], [] + for trial in range(n_trials): + n_min = min(pred_all.shape[0], gt_all.shape[0]) + perm_pred = torch.randperm(pred_all.shape[0])[:n_min] + perm_gt = torch.randperm(gt_all.shape[0])[:n_min] + m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2]) + w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"]) + + w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1) + w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1) + mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1) + self.log("test/W1_t1_combined", w1_mean, on_epoch=True) + self.log("test/W2_t1_combined", w2_mean, on_epoch=True) + self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True) + metrics_dict['t1_combined'] = { + "W1_mean": float(w1_mean), "W1_std": float(w1_std), + "W2_mean": float(w2_mean), "W2_std": float(w2_std), + "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std), + "n_trials": n_trials, + } + print(f"\n=== Combined @ t1 ===") + print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}") + print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}") + print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}") + + # Create results directory structure + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + figures_dir = f'{results_dir}/figures' + os.makedirs(figures_dir, exist_ok=True) + + # Save metrics to JSON + metrics_path = f'{results_dir}/metrics.json' + with open(metrics_path, 'w') as f: + json.dump(metrics_dict, f, indent=2) + print(f"Metrics saved to {metrics_path}") + + # Save detailed metrics to CSV + detailed_csv_path = f'{results_dir}/metrics_detailed.csv' + with open(detailed_csv_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std']) + for key in sorted(metrics_dict.keys()): + m = metrics_dict[key] + writer.writerow([key, + f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}', + f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}', + f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}']) + print(f"Detailed metrics CSV saved to {detailed_csv_path}") + + # ===== Plot branches ===== + self._plot_trametinib_branches(all_trajs, timepoint_data, figures_dir, + custom_cmap_1, custom_cmap_2, custom_cmap_3) + self._plot_trametinib_combined(all_trajs, timepoint_data, figures_dir) + + print(f"Trametinib figures saved to {figures_dir}") + + def _plot_trametinib_branches(self, all_trajs, timepoint_data, save_dir, + cmap1, cmap2, cmap3): + """Plot each branch separately.""" + branch_names = ['Branch 1', 'Branch 2', 'Branch 3'] + branch_colors = ['#9793F8', '#50B2D7', '#B83CFF'] + cmaps = [cmap1, cmap2, cmap3] + + # Compute global axis limits – handle single vs multi branch keys + all_coords = [] + if 't1_1' in timepoint_data: + tp_keys = ['t0'] + [f't1_{i+1}' for i in range(len(all_trajs))] + else: + tp_keys = ['t0', 't1'] + for key in tp_keys: + all_coords.append(timepoint_data[key][:, :2]) + for traj in all_trajs: + all_coords.append(traj.reshape(-1, traj.shape[-1])[:, :2]) + + all_coords = np.concatenate(all_coords, axis=0) + x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max() + y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max() + + x_margin = 0.05 * (x_max - x_min) + y_margin = 0.05 * (y_max - y_min) + x_min -= x_margin + x_max += x_margin + y_min -= y_margin + y_max += y_margin + + for i, traj in enumerate(all_trajs): + fig, ax = plt.subplots(figsize=(10, 8)) + c_end = branch_colors[i] + + # Plot timepoint background + t1_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1' + coords_list = [timepoint_data['t0'], timepoint_data[t1_key]] + tp_colors = ['#05009E', c_end] + t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1" + tp_labels = ["t=0", t1_label] + + for coords, color, label in zip(coords_list, tp_colors, tp_labels): + ax.scatter(coords[:, 0], coords[:, 1], + c=color, s=80, alpha=0.4, marker='x', + label=f'{label} cells', linewidth=1.5) + + # Plot continuous trajectories with LineCollection for speed + traj_2d = traj[:, :, :2] + n_time = traj_2d.shape[1] + color_vals = cmaps[i](np.linspace(0, 1, n_time)) + segments = [] + seg_colors = [] + for j in range(traj_2d.shape[0]): + pts = traj_2d[j] + segs = np.stack([pts[:-1], pts[1:]], axis=1) + segments.append(segs) + seg_colors.append(color_vals[:-1]) + segments = np.concatenate(segments, axis=0) + seg_colors = np.concatenate(seg_colors, axis=0) + lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8) + ax.add_collection(lc) + + # Start and end points + ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], + c='#05009E', s=30, marker='o', label='Trajectory Start', + zorder=5, edgecolors='white', linewidth=1) + ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], + c=c_end, s=30, marker='o', label='Trajectory End', + zorder=5, edgecolors='white', linewidth=1) + + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) + ax.set_xlabel("PC1", fontsize=12) + ax.set_ylabel("PC2", fontsize=12) + ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14) + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=16, frameon=False) + + plt.tight_layout() + plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300) + plt.close() + + def _plot_trametinib_combined(self, all_trajs, timepoint_data, save_dir): + """Plot all 3 branches together.""" + branch_names = ['Branch 1', 'Branch 2', 'Branch 3'] + branch_colors = ['#9793F8', '#50B2D7', '#B83CFF'] + + fig, ax = plt.subplots(figsize=(12, 10)) + + # Build timepoint keys/colors/labels depending on single vs multi branch + if 't1_1' in timepoint_data: + tp_keys = ['t0'] + [f't1_{j+1}' for j in range(len(all_trajs))] + tp_labels_list = ['t=0'] + [f't=1 (branch {j+1})' for j in range(len(all_trajs))] + else: + tp_keys = ['t0', 't1'] + tp_labels_list = ['t=0', 't=1'] + tp_colors = ['#05009E', '#9793F8', '#50B2D7', '#B83CFF'][:len(tp_keys)] + + # Plot timepoint background + for t_key, color, label in zip(tp_keys, tp_colors, tp_labels_list): + coords = timepoint_data[t_key] + ax.scatter(coords[:, 0], coords[:, 1], + c=color, s=80, alpha=0.4, marker='x', + label=f'{label} cells', linewidth=1.5) + + # Plot trajectories with color gradients + custom_colors_1 = ["#05009E", "#A19EFF", "#9793F8"] + custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"] + custom_colors_3 = ["#05009E", "#A19EFF", "#D577FF"] + cmaps = [ + LinearSegmentedColormap.from_list("tram_cmap1", custom_colors_1), + LinearSegmentedColormap.from_list("tram_cmap2", custom_colors_2), + LinearSegmentedColormap.from_list("tram_cmap3", custom_colors_3), + ] + for i, traj in enumerate(all_trajs): + traj_2d = traj[:, :, :2] + c_end = branch_colors[i] + cmap = cmaps[i] + n_time = traj_2d.shape[1] + color_vals = cmap(np.linspace(0, 1, n_time)) + segments = [] + seg_colors = [] + for j in range(traj_2d.shape[0]): + pts = traj_2d[j] + segs = np.stack([pts[:-1], pts[1:]], axis=1) + segments.append(segs) + seg_colors.append(color_vals[:-1]) + segments = np.concatenate(segments, axis=0) + seg_colors = np.concatenate(seg_colors, axis=0) + lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8) + ax.add_collection(lc) + + ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], + c='#05009E', s=30, marker='o', + label=f'{branch_names[i]} Start', + zorder=5, edgecolors='white', linewidth=1) + ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], + c=c_end, s=30, marker='o', + label=f'{branch_names[i]} End', + zorder=5, edgecolors='white', linewidth=1) + + ax.set_xlabel("PC1", fontsize=14) + ax.set_ylabel("PC2", fontsize=14) + ax.set_title("All Branch Trajectories with Timepoint Background", + fontsize=16, weight='bold') + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=12, frameon=False) + + plt.tight_layout() + plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300) + plt.close() + +class FlowNetTestVeres(GrowthNetTrain): + """Test class for Veres pancreatic endocrinogenesis experiment (3 or 5 branches).""" + + def test_step(self, batch, batch_idx): + # Handle both tuple and dict batch formats from CombinedLoader + if isinstance(batch, dict): + main_batch = batch["test_samples"][0] + metric_batch = batch["metric_samples"][0] + else: + # batch is a list/tuple + if isinstance(batch[0], dict): + # batch[0] contains the dict with test_samples and metric_samples + main_batch = batch[0]["test_samples"][0] + metric_batch = batch[0]["metric_samples"][0] + else: + # batch is a tuple: (test_samples, metric_samples) + main_batch = batch[0][0] + metric_batch = batch[1][0] + + # Get timepoint data (full datasets, not just val split) + timepoint_data = self.trainer.datamodule.get_timepoint_data() + device = main_batch["x0"][0].device + + # Use val x0 as initial conditions + x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device) + w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device) + full_batch = {"x0": (x0_all, w0_all)} + + time_points, all_endpoints, all_trajs, mass_over_time, energy_over_time, weights_over_time = self.get_mass_and_position(full_batch, metric_batch) + + n_branches = len(self.flow_nets) + + # trajectory time grid + t_span = torch.linspace(0, 1, 101).to(device) + + # `all_trajs` returned from `get_mass_and_position` is expected to be a list where each + # element is a sequence of per-timepoint tensors for that branch (shape [B, D] each). + # Convert each branch to [T, B, D] then to [B, T, D] for downstream processing. + trajs_TBD = [torch.stack(branch_list, dim=0) for branch_list in all_trajs] # each is [T, B, D] + trajs_BTD = [t.permute(1, 0, 2) for t in trajs_TBD] # each -> [B, T, D] + + all_trajs = [] + all_endpoints = [] + # will store per-branch intermediate frames: each entry -> tensor [B, n_intermediate, D] + all_intermediates = [] + + for traj in trajs_BTD: + # traj is [B, T, D] + # optionally inverse-transform if whitened + if self.whiten: + traj_np = traj.detach().cpu().numpy() + n_samples, n_time, n_dims = traj_np.shape + traj_flat = traj_np.reshape(-1, n_dims) + traj_inv_flat = self.trainer.datamodule.scaler.inverse_transform(traj_flat) + traj_inv = traj_inv_flat.reshape(n_samples, n_time, n_dims) + traj = torch.tensor(traj_inv, dtype=torch.float32) + + all_trajs.append(traj) + + # Collect six evenly spaced intermediate frames between t=0 and t=1 (exclude endpoints) + n_T = traj.shape[1] + # choose 8 points including endpoints -> take inner 6 as intermediates + inter_times = np.linspace(0.0, 1.0, 8)[1:-1] # 6 values + inter_indices = [int(round(t * (n_T - 1))) for t in inter_times] + # stack per-branch intermediate frames -> [B, 6, D] + intermediates = torch.stack([traj[:, idx, :] for idx in inter_indices], dim=1) + all_intermediates.append(intermediates) + + # Final endpoints (t=1) + all_endpoints.append(traj[:, -1, :]) + + # Run 5 trials with random subsampling for robust metrics + n_trials = 5 + metrics_dict = {} + + # --- Intermediate timepoints (t1-t6) combined metrics --- + intermediate_keys = sorted([k for k in timepoint_data.keys() + if k.startswith('t') and '_' not in k and k != 't0']) + + if intermediate_keys: + n_evals = min(6, len(intermediate_keys)) + for j in range(n_evals): + intermediate_key = intermediate_keys[j] + true_data_intermediate = torch.tensor(timepoint_data[intermediate_key], dtype=torch.float32) + + # Gather predicted intermediates across all branches + raw_intermediates = [branch[:, j, :] for branch in all_intermediates] + all_raw_concat = torch.cat(raw_intermediates, dim=0).cpu() # [n_branches*B, D] + + w1_t, w2_t, mmd_t = [], [], [] + w1_t_full, w2_t_full, mmd_t_full = [], [], [] + for trial in range(n_trials): + n_min = min(all_raw_concat.shape[0], true_data_intermediate.shape[0]) + perm_pred = torch.randperm(all_raw_concat.shape[0])[:n_min] + perm_gt = torch.randperm(true_data_intermediate.shape[0])[:n_min] + # 2D metrics (PC1-PC2) + m = compute_distribution_distances( + all_raw_concat[perm_pred, :2], true_data_intermediate[perm_gt, :2]) + w1_t.append(m["W1"]); w2_t.append(m["W2"]); mmd_t.append(m["MMD"]) + # Full-dimensional metrics (all PCs) + m_full = compute_distribution_distances( + all_raw_concat[perm_pred], true_data_intermediate[perm_gt]) + w1_t_full.append(m_full["W1"]); w2_t_full.append(m_full["W2"]); mmd_t_full.append(m_full["MMD"]) + + metrics_dict[f'{intermediate_key}_combined'] = { + "W1_mean": float(np.mean(w1_t)), "W1_std": float(np.std(w1_t, ddof=1)), + "W2_mean": float(np.mean(w2_t)), "W2_std": float(np.std(w2_t, ddof=1)), + "MMD_mean": float(np.mean(mmd_t)), "MMD_std": float(np.std(mmd_t, ddof=1)), + "W1_full_mean": float(np.mean(w1_t_full)), "W1_full_std": float(np.std(w1_t_full, ddof=1)), + "W2_full_mean": float(np.mean(w2_t_full)), "W2_full_std": float(np.std(w2_t_full, ddof=1)), + "MMD_full_mean": float(np.mean(mmd_t_full)), "MMD_full_std": float(np.std(mmd_t_full, ddof=1)), + } + self.log(f"test/W1_{intermediate_key}_combined", np.mean(w1_t), on_epoch=True) + self.log(f"test/W1_full_{intermediate_key}_combined", np.mean(w1_t_full), on_epoch=True) + print(f"{intermediate_key} combined — W1: {np.mean(w1_t):.6f}±{np.std(w1_t, ddof=1):.6f}, " + f"W2: {np.mean(w2_t):.6f}±{np.std(w2_t, ddof=1):.6f}, " + f"MMD: {np.mean(mmd_t):.6f}±{np.std(mmd_t, ddof=1):.6f}") + print(f"{intermediate_key} combined (full) — W1: {np.mean(w1_t_full):.6f}±{np.std(w1_t_full, ddof=1):.6f}, " + f"W2: {np.mean(w2_t_full):.6f}±{np.std(w2_t_full, ddof=1):.6f}, " + f"MMD: {np.mean(mmd_t_full):.6f}±{np.std(mmd_t_full, ddof=1):.6f}") + + # --- Final timepoint per-branch metrics --- + gt_keys = sorted([k for k in timepoint_data.keys() if k.startswith('t7_')]) + for i, endpoints in enumerate(all_endpoints): + true_data_key = f"t7_{i}" + if true_data_key not in timepoint_data: + print(f"Warning: {true_data_key} not found in timepoint_data") + continue + gt = torch.tensor(timepoint_data[true_data_key], dtype=torch.float32) + pred = endpoints.cpu() + + w1_br, w2_br, mmd_br = [], [], [] + w1_br_full, w2_br_full, mmd_br_full = [], [], [] + for trial in range(n_trials): + n_min = min(pred.shape[0], gt.shape[0]) + perm_pred = torch.randperm(pred.shape[0])[:n_min] + perm_gt = torch.randperm(gt.shape[0])[:n_min] + # 2D metrics (PC1-PC2) + m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2]) + w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"]) + # Full-dimensional metrics (all PCs) + m_full = compute_distribution_distances(pred[perm_pred], gt[perm_gt]) + w1_br_full.append(m_full["W1"]); w2_br_full.append(m_full["W2"]); mmd_br_full.append(m_full["MMD"]) + + metrics_dict[f"branch_{i}"] = { + "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)), + "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)), + "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)), + "W1_full_mean": float(np.mean(w1_br_full)), "W1_full_std": float(np.std(w1_br_full, ddof=1)), + "W2_full_mean": float(np.mean(w2_br_full)), "W2_full_std": float(np.std(w2_br_full, ddof=1)), + "MMD_full_mean": float(np.mean(mmd_br_full)), "MMD_full_std": float(np.std(mmd_br_full, ddof=1)), + } + self.log(f"test/W1_branch{i}", np.mean(w1_br), on_epoch=True) + self.log(f"test/W1_full_branch{i}", np.mean(w1_br_full), on_epoch=True) + print(f"Branch {i} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, " + f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, " + f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}") + print(f"Branch {i} (full) — W1: {np.mean(w1_br_full):.6f}±{np.std(w1_br_full, ddof=1):.6f}, " + f"W2: {np.mean(w2_br_full):.6f}±{np.std(w2_br_full, ddof=1):.6f}, " + f"MMD: {np.mean(mmd_br_full):.6f}±{np.std(mmd_br_full, ddof=1):.6f}") + + # --- Final timepoint combined metrics --- + gt_list = [torch.tensor(timepoint_data[k], dtype=torch.float32) for k in gt_keys] + if len(gt_list) > 0 and len(all_endpoints) > 0: + gt_all = torch.cat(gt_list, dim=0) + pred_all = torch.cat([e.cpu() for e in all_endpoints], dim=0) + + w1_trials, w2_trials, mmd_trials = [], [], [] + w1_trials_full, w2_trials_full, mmd_trials_full = [], [], [] + for trial in range(n_trials): + n_min = min(pred_all.shape[0], gt_all.shape[0]) + perm_pred = torch.randperm(pred_all.shape[0])[:n_min] + perm_gt = torch.randperm(gt_all.shape[0])[:n_min] + # 2D metrics (PC1-PC2) + m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2]) + w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"]) + # Full-dimensional metrics (all PCs) + m_full = compute_distribution_distances(pred_all[perm_pred], gt_all[perm_gt]) + w1_trials_full.append(m_full["W1"]); w2_trials_full.append(m_full["W2"]); mmd_trials_full.append(m_full["MMD"]) + + w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1) + w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1) + mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1) + w1_mean_f, w1_std_f = np.mean(w1_trials_full), np.std(w1_trials_full, ddof=1) + w2_mean_f, w2_std_f = np.mean(w2_trials_full), np.std(w2_trials_full, ddof=1) + mmd_mean_f, mmd_std_f = np.mean(mmd_trials_full), np.std(mmd_trials_full, ddof=1) + self.log("test/W1_t7_combined", w1_mean, on_epoch=True) + self.log("test/W2_t7_combined", w2_mean, on_epoch=True) + self.log("test/MMD_t7_combined", mmd_mean, on_epoch=True) + self.log("test/W1_full_t7_combined", w1_mean_f, on_epoch=True) + self.log("test/W2_full_t7_combined", w2_mean_f, on_epoch=True) + self.log("test/MMD_full_t7_combined", mmd_mean_f, on_epoch=True) + metrics_dict['t7_combined'] = { + "W1_mean": float(w1_mean), "W1_std": float(w1_std), + "W2_mean": float(w2_mean), "W2_std": float(w2_std), + "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std), + "W1_full_mean": float(w1_mean_f), "W1_full_std": float(w1_std_f), + "W2_full_mean": float(w2_mean_f), "W2_full_std": float(w2_std_f), + "MMD_full_mean": float(mmd_mean_f), "MMD_full_std": float(mmd_std_f), + "n_trials": n_trials, + } + print(f"\n=== Combined @ t7 ===") + print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}") + print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}") + print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}") + print(f"W1 (full): {w1_mean_f:.6f} ± {w1_std_f:.6f}") + print(f"W2 (full): {w2_mean_f:.6f} ± {w2_std_f:.6f}") + print(f"MMD (full): {mmd_mean_f:.6f} ± {mmd_std_f:.6f}") + + # Create results directory structure + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + figures_dir = f'{results_dir}/figures' + os.makedirs(figures_dir, exist_ok=True) + + # Save metrics to JSON + metrics_path = f'{results_dir}/metrics.json' + with open(metrics_path, 'w') as f: + json.dump(metrics_dict, f, indent=2) + print(f"Metrics saved to {metrics_path}") + + # Save detailed metrics to CSV + detailed_csv_path = f'{results_dir}/metrics_detailed.csv' + with open(detailed_csv_path, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['Metric_Group', + 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std', + 'W1_Full_Mean', 'W1_Full_Std', 'W2_Full_Mean', 'W2_Full_Std', 'MMD_Full_Mean', 'MMD_Full_Std']) + for key in sorted(metrics_dict.keys()): + m = metrics_dict[key] + writer.writerow([key, + f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}', + f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}', + f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}', + f'{m.get("W1_full_mean", 0):.6f}', f'{m.get("W1_full_std", 0):.6f}', + f'{m.get("W2_full_mean", 0):.6f}', f'{m.get("W2_full_std", 0):.6f}', + f'{m.get("MMD_full_mean", 0):.6f}', f'{m.get("MMD_full_std", 0):.6f}']) + print(f"Detailed metrics CSV saved to {detailed_csv_path}") + + # ===== Plot branches ===== + self._plot_veres_branches(all_trajs, timepoint_data, figures_dir, n_branches) + self._plot_veres_combined(all_trajs, timepoint_data, figures_dir, n_branches) + + print(f"Veres figures saved to {figures_dir}") + + def _plot_veres_branches(self, all_trajs, timepoint_data, save_dir, n_branches): + """Plot each branch separately in PCA space (PC1 vs PC2).""" + branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9', + '#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8', + '#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8', + '#64B5F6', '#C5E1A5'] + + # Project to first 2 PCs (data is already in PCA space) + t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2] + t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)] + + # Slice trajectories to first 2 PCs + trajs_2d = [] + for traj in all_trajs: + trajs_2d.append(traj.cpu().numpy()[:, :, :2]) # [n_samples, n_time, 2] + + # Compute global axis limits + all_coords = [t0_2d] + t7_2d + for traj_2d in trajs_2d: + all_coords.append(traj_2d.reshape(-1, 2)) + + all_coords = np.concatenate(all_coords, axis=0) + x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max() + y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max() + + x_margin = 0.05 * (x_max - x_min) + y_margin = 0.05 * (y_max - y_min) + x_min -= x_margin + x_max += x_margin + y_min -= y_margin + y_max += y_margin + + for i, traj_2d in enumerate(trajs_2d): + fig, ax = plt.subplots(figsize=(10, 8)) + c_end = branch_colors[i % len(branch_colors)] + + # Plot timepoint background + ax.scatter(t0_2d[:, 0], t0_2d[:, 1], + c='#05009E', s=80, alpha=0.4, marker='x', + label='t=0 cells', linewidth=1.5) + ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1], + c=c_end, s=80, alpha=0.4, marker='x', + label=f't=7 (branch {i+1}) cells', linewidth=1.5) + + # Plot continuous trajectories with LineCollection for speed + cmap_colors = ["#05009E", "#A19EFF", c_end] + cmap = LinearSegmentedColormap.from_list(f"veres_cmap_{i}", cmap_colors) + n_time = traj_2d.shape[1] + segments = [] + seg_colors = [] + color_vals = cmap(np.linspace(0, 1, n_time)) + for j in range(traj_2d.shape[0]): + pts = traj_2d[j] # [T, 2] + segs = np.stack([pts[:-1], pts[1:]], axis=1) # [T-1, 2, 2] + segments.append(segs) + seg_colors.append(color_vals[:-1]) + segments = np.concatenate(segments, axis=0) + seg_colors = np.concatenate(seg_colors, axis=0) + lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8) + ax.add_collection(lc) + + # Start and end points + ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], + c='#05009E', s=30, marker='o', label='Trajectory start (t=0)', + zorder=5, edgecolors='white', linewidth=1) + ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], + c=c_end, s=30, marker='o', label='Trajectory end (t=1)', + zorder=5, edgecolors='white', linewidth=1) + + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) + ax.set_xlabel("PC 1", fontsize=12) + ax.set_ylabel("PC 2", fontsize=12) + ax.set_title(f"Branch {i+1}: Trajectories (PCA)", fontsize=14) + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=9, frameon=False) + + plt.tight_layout() + plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300) + plt.close() + + def _plot_veres_combined(self, all_trajs, timepoint_data, save_dir, n_branches): + """Plot all branches together in PCA space (PC1 vs PC2).""" + branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9', + '#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8', + '#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8', + '#64B5F6', '#C5E1A5'] + + # Project to first 2 PCs (data is already in PCA space) + t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2] + t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)] + + # Slice trajectories to first 2 PCs + trajs_2d = [] + for traj in all_trajs: + trajs_2d.append(traj.cpu().numpy()[:, :, :2]) # [n_samples, n_time, 2] + + # Compute axis limits from REAL CELLS ONLY + all_coords_real = [t0_2d] + t7_2d + all_coords_real = np.concatenate(all_coords_real, axis=0) + x_min, x_max = all_coords_real[:, 0].min(), all_coords_real[:, 0].max() + y_min, y_max = all_coords_real[:, 1].min(), all_coords_real[:, 1].max() + x_margin = 0.05 * (x_max - x_min) + y_margin = 0.05 * (y_max - y_min) + x_min -= x_margin + x_max += x_margin + y_min -= y_margin + y_max += y_margin + + fig, ax = plt.subplots(figsize=(14, 12)) + ax.set_xlim(x_min, x_max) + ax.set_ylim(y_min, y_max) + + # Plot t=0 cells + ax.scatter(t0_2d[:, 0], t0_2d[:, 1], + c='#05009E', s=60, alpha=0.3, marker='x', + label='t=0 cells', linewidth=1.5) + + # Plot each branch's cells and trajectories + for i, traj_2d in enumerate(trajs_2d): + c_end = branch_colors[i % len(branch_colors)] + + # Plot t=7 cells for this branch + ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1], + c=c_end, s=60, alpha=0.3, marker='x', + label=f't=7 (branch {i+1})', linewidth=1.5) + + # Plot continuous trajectories with LineCollection for speed + cmap_colors = ["#05009E", "#A19EFF", c_end] + cmap = LinearSegmentedColormap.from_list(f"veres_combined_cmap_{i}", cmap_colors) + n_time = traj_2d.shape[1] + segments = [] + seg_colors = [] + color_vals = cmap(np.linspace(0, 1, n_time)) + for j in range(traj_2d.shape[0]): + pts = traj_2d[j] # [T, 2] + segs = np.stack([pts[:-1], pts[1:]], axis=1) # [T-1, 2, 2] + segments.append(segs) + seg_colors.append(color_vals[:-1]) + segments = np.concatenate(segments, axis=0) + seg_colors = np.concatenate(seg_colors, axis=0) + lc = LineCollection(segments, colors=seg_colors, linewidths=1.5, alpha=0.6) + ax.add_collection(lc) + + # Start and end points + ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1], + c='#05009E', s=20, marker='o', + zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7) + ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1], + c=c_end, s=20, marker='o', + zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7) + + ax.set_xlabel("PC 1", fontsize=14) + ax.set_ylabel("PC 2", fontsize=14) + ax.set_title(f"All {n_branches} Branch Trajectories (Veres) - PCA Projection", + fontsize=16, weight='bold') + ax.grid(True, alpha=0.3) + ax.legend(loc='upper right', fontsize=10, frameon=False, ncol=2) + + plt.tight_layout() + plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300) + plt.close() \ No newline at end of file diff --git a/src/branch_flow_net_train.py b/src/branch_flow_net_train.py new file mode 100755 index 0000000000000000000000000000000000000000..7a1fc1ca10d59977b4875f01ff1575775b279838 --- /dev/null +++ b/src/branch_flow_net_train.py @@ -0,0 +1,375 @@ +import os +import sys +import torch +import wandb +import matplotlib.pyplot as plt +import pytorch_lightning as pl +from torch.optim import AdamW +from torchmetrics.functional import mean_squared_error +from torchdyn.core import NeuralODE +from .networks.utils import flow_model_torch_wrapper +from .utils import wasserstein, plot_lidar +from .ema import EMA + +class BranchFlowNetTrainBase(pl.LightningModule): + def __init__( + self, + flow_matcher, + flow_nets, + skipped_time_points=None, + ot_sampler=None, + args=None, + ): + super().__init__() + self.args = args + + self.flow_matcher = flow_matcher + self.flow_nets = flow_nets # list of flow networks for each branch + self.ot_sampler = ot_sampler + self.skipped_time_points = skipped_time_points + + self.optimizer_name = args.flow_optimizer + self.lr = args.flow_lr + self.weight_decay = args.flow_weight_decay + self.whiten = args.whiten + self.working_dir = args.working_dir + + #branching + self.branches = len(flow_nets) + + def forward(self, t, xt, branch_idx): + # output velocity given branch_idx + return self.flow_nets[branch_idx](t, xt) + + def _compute_loss(self, main_batch): + + x0s = [main_batch["x0"][0]] + w0s = [main_batch["x0"][1]] + + x1s_list = [] + w1s_list = [] + + if self.branches > 1: + for i in range(self.branches): + x1s_list.append([main_batch[f"x1_{i+1}"][0]]) + w1s_list.append([main_batch[f"x1_{i+1}"][1]]) + else: + x1s_list.append([main_batch["x1"][0]]) + w1s_list.append([main_batch["x1"][1]]) + + assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches" + + loss = 0 + for branch_idx in range(self.branches): + ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx) + + t = torch.cat(ts) + xt = torch.cat(xts) + ut = torch.cat(uts) + vt = self(t[:, None], xt, branch_idx) + + loss += mean_squared_error(vt, ut) + + return loss + + def _process_flow(self, x0s, x1s, branch_idx): + ts, xts, uts = [], [], [] + t_start = self.timesteps[0] + + for i, (x0, x1) in enumerate(zip(x0s, x1s)): + + x0, x1 = torch.squeeze(x0), torch.squeeze(x1) + + if self.ot_sampler is not None: + x0, x1 = self.ot_sampler.sample_plan( + x0, + x1, + replace=True, + ) + if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]: + t_start_next = self.timesteps[i + 2] + else: + t_start_next = self.timesteps[i + 1] + + # edit to sample from correct flow matcher + t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow( + x0, x1, t_start, t_start_next, branch_idx + ) + + ts.append(t) + + xts.append(xt) + uts.append(ut) + t_start = t_start_next + return ts, xts, uts + + def training_step(self, batch, batch_idx): + # Handle both dict and tuple batch formats from CombinedLoader + if isinstance(batch, (list, tuple)): + batch = batch[0] + if isinstance(batch, dict) and "train_samples" in batch: + main_batch = batch["train_samples"] + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + else: + # Fallback + main_batch = batch.get("train_samples", batch) + + print("Main batch length") + print(len(main_batch["x0"])) + + # edited to simulate 100 steps + self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() + loss = self._compute_loss(main_batch) + if self.flow_matcher.alpha != 0: + self.log( + "FlowNet/mean_geopath_cfm", + (self.flow_matcher.geopath_net_output.abs().mean()), + on_step=False, + on_epoch=True, + prog_bar=True, + ) + + self.log( + "FlowNet/train_loss_cfm", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + + return loss + + def validation_step(self, batch, batch_idx): + # Handle both dict and tuple batch formats from CombinedLoader + if isinstance(batch, (list, tuple)): + batch = batch[0] + if isinstance(batch, dict) and "val_samples" in batch: + main_batch = batch["val_samples"] + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + else: + # Fallback + main_batch = batch.get("val_samples", batch) + + self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() + val_loss = self._compute_loss(main_batch) + self.log( + "FlowNet/val_loss_cfm", + val_loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return val_loss + + def optimizer_step(self, *args, **kwargs): + super().optimizer_step(*args, **kwargs) + + for net in self.flow_nets: + if isinstance(net, EMA): + net.update_ema() + + def configure_optimizers(self): + if self.optimizer_name == "adamw": + optimizer = AdamW( + self.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + elif self.optimizer_name == "adam": + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.lr, + ) + + return optimizer + + +class FlowNetTrainTrajectory(BranchFlowNetTrainBase): + def test_step(self, batch, batch_idx): + data_type = self.args.data_type + node = NeuralODE( + flow_model_torch_wrapper(self.flow_nets), + solver="euler", + sensitivity="adjoint", + atol=1e-5, + rtol=1e-5, + ) + + t_exclude = self.skipped_time_points[0] if self.skipped_time_points else None + if t_exclude is not None: + traj = node.trajectory( + batch[t_exclude - 1], + t_span=torch.linspace( + self.timesteps[t_exclude - 1], self.timesteps[t_exclude], 101 + ), + ) + X_mid_pred = traj[-1] + traj = node.trajectory( + batch[t_exclude - 1], + t_span=torch.linspace( + self.timesteps[t_exclude - 1], + self.timesteps[t_exclude + 1], + 101, + ), + ) + + EMD = wasserstein(X_mid_pred, batch[t_exclude], p=1) + self.final_EMD = EMD + + self.log("test_EMD", EMD, on_step=False, on_epoch=True, prog_bar=True) + +class FlowNetTrainCell(BranchFlowNetTrainBase): + def test_step(self, batch, batch_idx): + x0 = batch[0]["test_samples"][0]["x0"][0] # [B, D] + dataset_points = batch[0]["test_samples"][0]["dataset"][0] # full dataset, [N, D] + t_span = torch.linspace(0, 1, 101) + + all_trajs = [] + + for i, flow_net in enumerate(self.flow_nets): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ) + + with torch.no_grad(): + traj = node.trajectory(x0, t_span).cpu() # [T, B, D] + + if self.whiten: + traj_shape = traj.shape + traj = traj.reshape(-1, traj.shape[-1]) + traj = self.trainer.datamodule.scaler.inverse_transform( + traj.cpu().detach().numpy() + ).reshape(traj_shape) + dataset_points = self.trainer.datamodule.scaler.inverse_transform( + dataset_points.cpu().detach().numpy() + ) + + traj = torch.tensor(traj) + traj = torch.transpose(traj, 0, 1) # [B, T, D] + all_trajs.append(traj) + + dataset_2d = dataset_points[:, :2] if isinstance(dataset_points, torch.Tensor) else dataset_points[:, :2] + + # ===== Plot all 2D trajectories together with dataset and start/end points ===== + fig, ax = plt.subplots(figsize=(6, 5)) + dataset_2d = dataset_2d.cpu().numpy() + ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1) + for traj in all_trajs: + traj_2d = traj[..., :2] # [B, T, 2] + for i in range(traj_2d.shape[0]): + ax.plot(traj_2d[i, :, 0], traj_2d[i, :, 1], alpha=0.8, zorder=2) + ax.scatter(traj_2d[i, 0, 0], traj_2d[i, 0, 1], c='green', s=10, label="t=0" if i == 0 else "", zorder=3) + ax.scatter(traj_2d[i, -1, 0], traj_2d[i, -1, 1], c='red', s=10, label="t=1" if i == 0 else "", zorder=3) + + ax.set_title("All Branch Trajectories (2D) with Dataset") + ax.set_xlabel("x") + ax.set_ylabel("y") + plt.axis("equal") + handles, labels = ax.get_legend_handles_labels() + if labels: + ax.legend() + + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + save_path = os.path.join(results_dir, 'figures') + + os.makedirs(save_path, exist_ok=True) + plt.savefig(f'{save_path}/{self.args.data_name}_all_branches.png', dpi=300) + plt.close() + + # ===== Plot each 2D trajectory separately with dataset and endpoints ===== + for i, traj in enumerate(all_trajs): + traj_2d = traj[..., :2] + fig, ax = plt.subplots(figsize=(6, 5)) + ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1) + for j in range(traj_2d.shape[0]): + ax.plot(traj_2d[j, :, 0], traj_2d[j, :, 1], alpha=0.9, zorder=2) + ax.scatter(traj_2d[j, 0, 0], traj_2d[j, 0, 1], c='green', s=12, label="t=0" if j == 0 else "", zorder=3) + ax.scatter(traj_2d[j, -1, 0], traj_2d[j, -1, 1], c='red', s=12, label="t=1" if j == 0 else "", zorder=3) + + ax.set_title(f"Branch {i + 1} Trajectories (2D) with Dataset") + ax.set_xlabel("x") + ax.set_ylabel("y") + plt.axis("equal") + handles, labels = ax.get_legend_handles_labels() + if labels: + ax.legend() + plt.savefig(f'{save_path}/{self.args.data_name}_branch_{i + 1}.png', dpi=300) + plt.close() + +class FlowNetTrainLidar(BranchFlowNetTrainBase): + def test_step(self, batch, batch_idx): + # Handle both tuple and dict batch formats from CombinedLoader + if isinstance(batch, dict): + main_batch = batch["test_samples"][0] + metric_batch = batch["metric_samples"][0] + else: + # batch is a tuple: (test_samples, metric_samples) + main_batch = batch[0][0] + metric_batch = batch[1][0] + + x0 = main_batch["x0"][0] # [B, D] + cloud_points = main_batch["dataset"][0] # full dataset, [N, D] + t_span = torch.linspace(0, 1, 101) + + all_trajs = [] + + for i, flow_net in enumerate(self.flow_nets): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ) + + with torch.no_grad(): + traj = node.trajectory(x0, t_span).cpu() # [T, B, D] + + if self.whiten: + traj_shape = traj.shape + traj = traj.reshape(-1, 3) + traj = self.trainer.datamodule.scaler.inverse_transform( + traj.cpu().detach().numpy() + ).reshape(traj_shape) + + traj = torch.tensor(traj) + traj = torch.transpose(traj, 0, 1) # [B, T, D] + all_trajs.append(traj) + + # Inverse-transform the point cloud once + if self.whiten: + cloud_points = torch.tensor( + self.trainer.datamodule.scaler.inverse_transform( + cloud_points.cpu().detach().numpy() + ) + ) + + # Create directory for saving figures + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + lidar_fig_dir = os.path.join(results_dir, 'figures') + os.makedirs(lidar_fig_dir, exist_ok=True) + + # ===== Plot all trajectories together ===== + fig = plt.figure(figsize=(6, 5)) + ax = fig.add_subplot(111, projection="3d", computed_zorder=False) + ax.view_init(elev=30, azim=-115, roll=0) + for i, traj in enumerate(all_trajs): + plot_lidar(ax, cloud_points, xs=traj, branch_idx=i) + plt.savefig(os.path.join(lidar_fig_dir, 'lidar_all_branches.png'), dpi=300) + plt.close() + + # ===== Plot each trajectory separately ===== + for i, traj in enumerate(all_trajs): + fig = plt.figure(figsize=(6, 5)) + ax = fig.add_subplot(111, projection="3d", computed_zorder=False) + ax.view_init(elev=30, azim=-115, roll=0) + plot_lidar(ax, cloud_points, xs=traj, branch_idx=i) + plt.savefig(os.path.join(lidar_fig_dir, f'lidar_branch_{i + 1}.png'), dpi=300) + plt.close() \ No newline at end of file diff --git a/src/branch_growth_net_train.py b/src/branch_growth_net_train.py new file mode 100755 index 0000000000000000000000000000000000000000..941ca4d1c5acd4c3cef64635d62aba71435566ba --- /dev/null +++ b/src/branch_growth_net_train.py @@ -0,0 +1,994 @@ +import os +import sys +import torch +import wandb +import matplotlib.pyplot as plt +import pytorch_lightning as pl +from torch.optim import AdamW +from torchmetrics.functional import mean_squared_error +from torchdyn.core import NeuralODE +import numpy as np +import lpips +from .networks.utils import flow_model_torch_wrapper +from .utils import plot_lidar +from .ema import EMA +from torchdiffeq import odeint as odeint2 +from .losses.energy_loss import EnergySolver, ReconsLoss + +class GrowthNetTrain(pl.LightningModule): + def __init__( + self, + flow_nets, + growth_nets, + skipped_time_points=None, + ot_sampler=None, + args=None, + + state_cost=None, + data_manifold_metric=None, + + joint = False + ): + super().__init__() + #self.save_hyperparameters() + self.flow_nets = flow_nets + + if not joint: + for param in self.flow_nets.parameters(): + param.requires_grad = False + + self.growth_nets = growth_nets # list of growth networks for each branch + + self.ot_sampler = ot_sampler + self.skipped_time_points = skipped_time_points + + self.optimizer_name = args.growth_optimizer + self.lr = args.growth_lr + self.weight_decay = args.growth_weight_decay + self.whiten = args.whiten + self.working_dir = args.working_dir + + self.args = args + + #branching + self.state_cost = state_cost + self.data_manifold_metric = data_manifold_metric + self.branches = len(growth_nets) + self.metric_clusters = args.metric_clusters + + self.recons_loss = ReconsLoss() + + # loss weights + self.lambda_energy = args.lambda_energy + self.lambda_mass = args.lambda_mass + self.lambda_match = args.lambda_match + self.lambda_recons = args.lambda_recons + + self.joint = joint + + def forward(self, t, xt, branch_idx): + # output growth rate given branch_idx + return self.growth_nets[branch_idx](t, xt) + + def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False): + x0s = main_batch["x0"][0] + w0s = main_batch["x0"][1] + x1s_list = [] + w1s_list = [] + + if self.branches > 1: + for i in range(self.branches): + x1s_list.append([main_batch[f"x1_{i+1}"][0]]) + w1s_list.append([main_batch[f"x1_{i+1}"][1]]) + else: + x1s_list.append([main_batch["x1"][0]]) + w1s_list.append([main_batch["x1"][1]]) + + if self.args.manifold: + #changed + if self.metric_clusters == 7 and self.branches == 6: + # Weinreb 6-branch scenario: cluster 0 (root) → clusters 1-6 (6 branches) + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + (metric_samples_batch[0], metric_samples_batch[3]), # x0 → x1_3 (branch 3) + (metric_samples_batch[0], metric_samples_batch[4]), # x0 → x1_4 (branch 4) + (metric_samples_batch[0], metric_samples_batch[5]), # x0 → x1_5 (branch 5) + (metric_samples_batch[0], metric_samples_batch[6]), # x0 → x1_6 (branch 6) + ] + elif self.metric_clusters == 4: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + (metric_samples_batch[0], metric_samples_batch[3]), + ] + elif self.metric_clusters == 3: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + ] + elif self.metric_clusters == 2 and self.branches == 2: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2) + ] + elif self.metric_clusters == 2: + # For any number of branches with 2 metric clusters (initial vs remaining) + # All branches use the same metric cluster pair + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]) # x0 → all branches + ] * self.branches + else: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + ] + + batch_size = x0s.shape[0] + + assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches" + + energy_loss = [0.] * self.branches + mass_loss = 0. + neg_weight_penalty = 0. + match_loss = [0.] * self.branches + recons_loss = [0.] * self.branches + + dtype = x0s[0].dtype + #w0s = torch.zeros((batch_size, 1), dtype=dtype) + m0s = torch.zeros_like(w0s, dtype=dtype) + start_state = (x0s, w0s, m0s) + + xt = [x0s.clone() for _ in range(self.branches)] + w0_branch = torch.zeros_like(w0s, dtype=dtype) + w0_branches = [] + w0_branches.append(w0s) + for _ in range(self.branches - 1): + w0_branches.append(w0_branch) + #w0_branches = [w0_branch.clone() for _ in range(self.branches - 1)] + wt = w0_branches + + mt = [m0s.clone() for _ in range(self.branches)] + + # loop through timesteps + for step_idx, (s, t) in enumerate(zip(self.timesteps[:-1], self.timesteps[1:])): + time = torch.Tensor([s, t]) + + total_w_t = 0 + # loop through branches + for i in range(self.branches): + + if self.args.manifold: + start_samples, end_samples = branch_sample_pairs[i] + samples = torch.cat([start_samples, end_samples], dim=0) + else: + samples = None + + # initialize weight and energy + start_state = (xt[i], wt[i], mt[i]) + + # loop over timesteps + xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx) + + # placeholders for next state + xt_last = xt_next[-1] + wt_last = wt_next[-1] + mt_last = mt_next[-1] + + total_w_t += wt_last + + energy_loss[i] += (mt_last - mt[i]) + neg_weight_penalty += torch.relu(-wt_last).sum() + + # update branch state + xt[i] = xt_last.clone().detach() + wt[i] = wt_last.clone().detach() + mt[i] = mt_last.clone().detach() + + # calculate mass loss from all branches + target = torch.ones_like(total_w_t) + mass_loss += mean_squared_error(total_w_t, target) + + # calculate loss that matches final weights + for i in range(self.branches): + match_loss[i] = mean_squared_error(wt[i], w1s_list[i][0]) + # compute reconstruction loss + recons_loss[i] = self.recons_loss(xt[i], x1s_list[i][0]) + + # average across time steps (loop runs len(timesteps)-1 times) + mass_loss = mass_loss / max(len(self.timesteps) - 1, 1) + + + # Weighted mean across branches (inversely weighted by cluster size) + # Get cluster sizes from datamodule if available + if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'cluster_sizes'): + cluster_sizes = self.trainer.datamodule.cluster_sizes + max_size = max(cluster_sizes) + # Inverse weighting: smaller clusters get higher weight + branch_weights = torch.tensor([max_size / size for size in cluster_sizes], + dtype=energy_loss[0].dtype, device=energy_loss[0].device) + # Normalize weights to sum to num_branches for fair comparison + branch_weights = branch_weights * self.branches / branch_weights.sum() + + energy_loss = torch.mean(torch.stack([e.mean() for e in energy_loss]) * branch_weights) + match_loss = torch.mean(torch.stack(match_loss) * branch_weights) + recons_loss = torch.mean(torch.stack(recons_loss) * branch_weights) + else: + # Fallback to uniform weighting + energy_loss = torch.mean(torch.stack([e.mean() for e in energy_loss])) + match_loss = torch.mean(torch.stack(match_loss)) + recons_loss = torch.mean(torch.stack(recons_loss)) + + loss = (self.lambda_energy * energy_loss) + (self.lambda_mass * (mass_loss + neg_weight_penalty)) + (self.lambda_match * match_loss) \ + + (self.lambda_recons * recons_loss) + + if self.joint: + if validation: + self.log("JointTrain/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True) + else: + self.log("JointTrain/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True) + else: + if validation: + self.log("GrowthNet/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True) + else: + self.log("GrowthNet/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True) + + return loss + + def take_step(self, t, start_state, branch_idx, samples=None, timestep_idx=0): + + flow_net = self.flow_nets[branch_idx] + growth_net = self.growth_nets[branch_idx] + + + x_t, w_t, m_t = odeint2(EnergySolver(flow_net, growth_net, self.state_cost, self.data_manifold_metric, samples, timestep_idx), start_state, t, options=dict(step_size=0.1),method='euler') + + return x_t, w_t, m_t + + def training_step(self, batch, batch_idx): + if isinstance(batch, (list, tuple)): + batch = batch[0] + if isinstance(batch, dict) and "train_samples" in batch: + main_batch = batch["train_samples"] + metric_batch = batch["metric_samples"] + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + if isinstance(metric_batch, tuple): + metric_batch = metric_batch[0] + else: + # Fallback + main_batch = batch.get("train_samples", batch) + metric_batch = batch.get("metric_samples", []) + + self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() + loss = self._compute_loss(main_batch, metric_batch, validation=False) + + if self.joint: + self.log( + "JointTrain/train_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + else: + self.log( + "GrowthNet/train_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return loss + + def validation_step(self, batch, batch_idx): + if isinstance(batch, (list, tuple)): + batch = batch[0] + if isinstance(batch, dict) and "val_samples" in batch: + main_batch = batch["val_samples"] + metric_batch = batch["metric_samples"] + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + if isinstance(metric_batch, tuple): + metric_batch = metric_batch[0] + else: + # Fallback + main_batch = batch.get("val_samples", batch) + metric_batch = batch.get("metric_samples", []) + + self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() + val_loss = self._compute_loss(main_batch, metric_batch, validation=True) + + if self.joint: + self.log( + "JointTrain/val_loss", + val_loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + else: + self.log( + "GrowthNet/val_loss", + val_loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return val_loss + + def optimizer_step(self, *args, **kwargs): + super().optimizer_step(*args, **kwargs) + for net in self.growth_nets: + if isinstance(net, EMA): + net.update_ema() + if self.joint: + for net in self.flow_nets: + if isinstance(net, EMA): + net.update_ema() + + def configure_optimizers(self): + params = [] + for net in self.growth_nets: + params += list(net.parameters()) + + if self.joint: + for net in self.flow_nets: + params += list(net.parameters()) + + if self.optimizer_name == "adamw": + optimizer = AdamW( + params, + lr=self.lr, + weight_decay=self.weight_decay, + ) + elif self.optimizer_name == "adam": + optimizer = torch.optim.Adam( + params, + lr=self.lr, + ) + + return optimizer + + @torch.no_grad() + def get_mass_and_position(self, main_batch, metric_samples_batch=None): + if isinstance(main_batch, dict): + main_batch = main_batch + else: + main_batch = main_batch[0] + + x0s = main_batch["x0"][0] + w0s = main_batch["x0"][1] + + if self.args.manifold: + if self.metric_clusters == 4: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + (metric_samples_batch[0], metric_samples_batch[3]), + ] + elif self.metric_clusters == 3: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + ] + elif self.metric_clusters == 2 and self.branches == 2: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2) + ] + elif self.metric_clusters == 2: + # For any number of branches with 2 metric clusters (initial vs remaining) + # All branches use the same metric cluster pair + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]) # x0 → all branches + ] * self.branches + else: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + ] + + batch_size = x0s.shape[0] + dtype = x0s[0].dtype + + m0s = torch.zeros_like(w0s, dtype=dtype) + xt = [x0s.clone() for _ in range(self.branches)] + + w0_branch = torch.zeros_like(w0s, dtype=dtype) + w0_branches = [] + w0_branches.append(w0s) + for _ in range(self.branches - 1): + w0_branches.append(w0_branch) + + wt = w0_branches + mt = [m0s.clone() for _ in range(self.branches)] + + time_points = [] + mass_over_time = [[] for _ in range(self.branches)] + energy_over_time = [[] for _ in range(self.branches)] + # record per-sample weights at each time for each branch (to allow OT with per-sample masses) + weights_over_time = [[] for _ in range(self.branches)] + all_trajs = [[] for _ in range(self.branches)] + + t_span = torch.linspace(0, 1, 101) + for step_idx, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])): + time_points.append(t.item()) + time = torch.Tensor([s, t]) + + for i in range(self.branches): + if self.args.manifold: + start_samples, end_samples = branch_sample_pairs[i] + samples = torch.cat([start_samples, end_samples], dim=0) + else: + samples = None + + start_state = (xt[i], wt[i], mt[i]) + xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx) + + xt[i] = xt_next[-1].clone().detach() + wt[i] = wt_next[-1].clone().detach() + mt[i] = mt_next[-1].clone().detach() + + all_trajs[i].append(xt[i].clone().detach()) + mass_over_time[i].append(wt[i].mean().item()) + energy_over_time[i].append(mt[i].mean().item()) + # store per-sample weights (clone to detach from graph) + try: + weights_over_time[i].append(wt[i].clone().detach()) + except Exception: + # fallback: store mean as singleton tensor + weights_over_time[i].append(torch.tensor(wt[i].mean().item()).unsqueeze(0)) + + return time_points, xt, all_trajs, mass_over_time, energy_over_time, weights_over_time + + @torch.no_grad() + def _plot_mass_and_energy(self, main_batch, metric_samples_batch=None, save_dir=None): + x0s = main_batch["x0"][0] + w0s = main_batch["x0"][1] + + if self.args.manifold: + if self.metric_clusters == 7 and self.branches == 6: + # Weinreb 6-branch scenario: cluster 0 (root) → clusters 1-6 (6 branches) + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + (metric_samples_batch[0], metric_samples_batch[3]), # x0 → x1_3 (branch 3) + (metric_samples_batch[0], metric_samples_batch[4]), # x0 → x1_4 (branch 4) + (metric_samples_batch[0], metric_samples_batch[5]), # x0 → x1_5 (branch 5) + (metric_samples_batch[0], metric_samples_batch[6]), # x0 → x1_6 (branch 6) + ] + elif self.metric_clusters == 4: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + (metric_samples_batch[0], metric_samples_batch[3]), + ] + elif self.metric_clusters == 3: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + ] + elif self.metric_clusters == 2 and self.branches == 2: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2) + ] + else: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + ] + + batch_size = x0s.shape[0] + dtype = x0s[0].dtype + + m0s = torch.zeros_like(w0s, dtype=dtype) + xt = [x0s.clone() for _ in range(self.branches)] + + w0_branch = torch.zeros_like(w0s, dtype=dtype) + w0_branches = [] + w0_branches.append(w0s) + for _ in range(self.branches - 1): + w0_branches.append(w0_branch) + + wt = w0_branches + mt = [m0s.clone() for _ in range(self.branches)] + + time_points = [] + mass_over_time = [[] for _ in range(self.branches)] + energy_over_time = [[] for _ in range(self.branches)] + + t_span = torch.linspace(0, 1, 101) + for step_idx, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])): + time_points.append(t.item()) + time = torch.Tensor([s, t]) + + for i in range(self.branches): + if self.args.manifold: + start_samples, end_samples = branch_sample_pairs[i] + samples = torch.cat([start_samples, end_samples], dim=0) + else: + samples = None + + start_state = (xt[i], wt[i], mt[i]) + xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx) + + xt[i] = xt_next[-1].clone().detach() + wt[i] = wt_next[-1].clone().detach() + mt[i] = mt_next[-1].clone().detach() + + mass_over_time[i].append(wt[i].mean().item()) + energy_over_time[i].append(mt[i].mean().item()) + + if save_dir is None: + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + save_dir = os.path.join(self.args.working_dir, 'results', run_name, 'figures') + os.makedirs(save_dir, exist_ok=True) + + # Use tab10 colormap to get visually distinct colors + if self.args.branches == 3: + branch_colors = ['#9793F8', '#50B2D7', '#D577FF'] # tuple of RGBs + else: + branch_colors = ['#50B2D7', '#D577FF'] # tuple of RGBs + + # --- Plot Mass --- + plt.figure(figsize=(8, 5)) + for i in range(self.branches): + color = branch_colors[i] + plt.plot(time_points, mass_over_time[i], color=color, linewidth=2.5, label=f"Mass Branch {i}") + plt.xlabel("Time") + plt.ylabel("Mass") + plt.title("Mass Evolution per Branch") + plt.legend() + plt.grid(True) + if self.joint: + mass_path = os.path.join(save_dir, f"{self.args.data_name}_joint_mass.png") + else: + mass_path = os.path.join(save_dir, f"{self.args.data_name}_growth_mass.png") + plt.savefig(mass_path, dpi=300, bbox_inches="tight") + plt.close() + + # --- Plot Energy --- + plt.figure(figsize=(8, 5)) + for i in range(self.branches): + color = branch_colors[i] + plt.plot(time_points, energy_over_time[i], color=color, linewidth=2.5, label=f"Energy Branch {i}") + plt.xlabel("Time") + plt.ylabel("Energy") + plt.title("Energy Evolution per Branch") + plt.legend() + plt.grid(True) + if self.joint: + energy_path = os.path.join(save_dir, f"{self.args.data_name}_joint_energy.png") + else: + energy_path = os.path.join(save_dir, f"{self.args.data_name}_growth_energy.png") + plt.savefig(energy_path, dpi=300, bbox_inches="tight") + plt.close() + + +class GrowthNetTrainLidar(GrowthNetTrain): + def test_step(self, batch, batch_idx): + # Handle both tuple and dict batch formats from CombinedLoader + if isinstance(batch, dict): + main_batch = batch["test_samples"][0] + metric_batch = batch["metric_samples"][0] + else: + # batch is a tuple: (test_samples, metric_samples) + main_batch = batch[0][0] + metric_batch = batch[1][0] + + self._plot_mass_and_energy(main_batch, metric_batch) + + x0 = main_batch["x0"][0] # [B, D] + cloud_points = main_batch["dataset"][0] # full dataset, [N, D] + t_span = torch.linspace(0, 1, 101) + + + all_trajs = [] + + for i, flow_net in enumerate(self.flow_nets): + node = NeuralODE( + flow_model_torch_wrapper(flow_net), + solver="euler", + sensitivity="adjoint", + ) + + with torch.no_grad(): + traj = node.trajectory(x0, t_span).cpu() # [T, B, D] + + if self.whiten: + traj_shape = traj.shape + traj = traj.reshape(-1, 3) + traj = self.trainer.datamodule.scaler.inverse_transform( + traj.cpu().detach().numpy() + ).reshape(traj_shape) + + traj = torch.tensor(traj) + traj = torch.transpose(traj, 0, 1) # [B, T, D] + all_trajs.append(traj) + + # Inverse-transform the point cloud once + if self.whiten: + cloud_points = torch.tensor( + self.trainer.datamodule.scaler.inverse_transform( + cloud_points.cpu().detach().numpy() + ) + ) + + # ===== Plot all trajectories together ===== + fig = plt.figure(figsize=(6, 5)) + ax = fig.add_subplot(111, projection="3d", computed_zorder=False) + ax.view_init(elev=30, azim=-115, roll=0) + for i, traj in enumerate(all_trajs): + plot_lidar(ax, cloud_points, xs=traj, branch_idx=i) + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + lidar_fig_dir = os.path.join(results_dir, 'figures') + os.makedirs(lidar_fig_dir, exist_ok=True) + if self.joint: + plt.savefig(os.path.join(lidar_fig_dir, 'joint_lidar_all_branches.png'), dpi=300) + else: + plt.savefig(os.path.join(lidar_fig_dir, 'growth_lidar_all_branches.png'), dpi=300) + plt.close() + + # ===== Plot each trajectory separately ===== + for i, traj in enumerate(all_trajs): + fig = plt.figure(figsize=(6, 5)) + ax = fig.add_subplot(111, projection="3d", computed_zorder=False) + ax.view_init(elev=30, azim=-115, roll=0) + plot_lidar(ax, cloud_points, xs=traj, branch_idx=i) + if self.joint: + plt.savefig(os.path.join(lidar_fig_dir, f'joint_lidar_branch_{i + 1}.png'), dpi=300) + else: + plt.savefig(os.path.join(lidar_fig_dir, f'growth_lidar_branch_{i + 1}.png'), dpi=300) + plt.close() + +class GrowthNetTrainCell(GrowthNetTrain): + def test_step(self, batch, batch_idx): + if self.args.data_type in ["scrna", "tahoe"]: + main_batch = batch[0]["test_samples"][0] + metric_batch = batch[0]["metric_samples"][0] + else: + main_batch = batch["test_samples"][0] + metric_batch = batch["metric_samples"][0] + + self._plot_mass_and_energy(main_batch, metric_batch) + + +class SequentialGrowthNetTrain(pl.LightningModule): + """ + Sequential growth network training for multi-timepoint data. + Learns growth rates for transitions between consecutive timepoints. + """ + def __init__( + self, + flow_nets, + growth_nets, + skipped_time_points=None, + ot_sampler=None, + args=None, + data_manifold_metric=None, + joint=False + ): + super().__init__() + self.flow_nets = flow_nets + + if not joint: + for param in self.flow_nets.parameters(): + param.requires_grad = False + + self.growth_nets = growth_nets + self.ot_sampler = ot_sampler + self.skipped_time_points = skipped_time_points + + self.optimizer_name = args.growth_optimizer + self.lr = args.growth_lr + self.weight_decay = args.growth_weight_decay + self.whiten = args.whiten + self.working_dir = args.working_dir + + self.args = args + self.data_manifold_metric = data_manifold_metric + self.branches = len(growth_nets) + self.metric_clusters = args.metric_clusters + + self.recons_loss = ReconsLoss() + + # loss weights + self.lambda_energy = args.lambda_energy + self.lambda_mass = args.lambda_mass + self.lambda_match = args.lambda_match + self.lambda_recons = args.lambda_recons + + self.joint = joint + self.num_timepoints = None + self.timepoint_keys = None + + def forward(self, t, xt, branch_idx): + return self.growth_nets[branch_idx](t, xt) + + def setup(self, stage=None): + """Initialize timepoint keys before training/validation starts.""" + if self.timepoint_keys is None: + timepoint_data = self.trainer.datamodule.get_timepoint_data() + self.timepoint_keys = [k for k in sorted(timepoint_data.keys()) + if not any(x in k for x in ['_', 'time_labels'])] + self.num_timepoints = len(self.timepoint_keys) + print(f"Training sequential growth for {self.num_timepoints} timepoints: {self.timepoint_keys}") + + def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False): + """Compute loss for sequential growth between timepoints.""" + x0s = main_batch["x0"][0] + w0s = main_batch["x0"][1] + + # Setup metric sample pairs + if self.args.manifold: + if self.metric_clusters == 2: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]) + ] * self.branches + else: + branch_sample_pairs = [] + for b in range(self.branches): + if b + 1 < len(metric_samples_batch): + branch_sample_pairs.append( + (metric_samples_batch[0], metric_samples_batch[b + 1]) + ) + else: + branch_sample_pairs.append( + (metric_samples_batch[0], metric_samples_batch[1]) + ) + + total_loss = 0 + total_energy_loss = 0 + total_mass_loss = 0 + total_match_loss = 0 + total_recons_loss = 0 + num_transitions = 0 + + # Process each consecutive timepoint transition + for i in range(len(self.timepoint_keys) - 1): + t_curr_key = self.timepoint_keys[i] + t_next_key = self.timepoint_keys[i + 1] + + batch_curr_key = f"x{t_curr_key.replace('t', '').replace('final', '1')}" + x_curr = main_batch[batch_curr_key][0] + w_curr = main_batch[batch_curr_key][1] + + if i == len(self.timepoint_keys) - 2: + # Final transition to branches + # Get cluster size weights if available + if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'cluster_sizes'): + cluster_sizes = self.trainer.datamodule.cluster_sizes + max_size = max(cluster_sizes) + # Inverse weighting: smaller clusters get higher weight + branch_weights = [max_size / size for size in cluster_sizes] + else: + branch_weights = [1.0] * self.branches + + for b in range(self.branches): + x_next = main_batch[f"x1_{b+1}"][0] + w_next = main_batch[f"x1_{b+1}"][1] + + # Compute growth-based loss for this transition + loss, energy_l, mass_l, match_l, recons_l = self._compute_transition_loss( + x_curr, w_curr, x_next, w_next, b, i, + branch_sample_pairs[b] if self.args.manifold else None + ) + # Apply branch weight + total_loss += loss * branch_weights[b] + total_energy_loss += energy_l * branch_weights[b] + total_mass_loss += mass_l * branch_weights[b] + total_match_loss += match_l * branch_weights[b] + total_recons_loss += recons_l * branch_weights[b] + num_transitions += 1 + else: + # Regular consecutive timepoints + batch_next_key = f"x{t_next_key.replace('t', '').replace('final', '1')}" + x_next = main_batch[batch_next_key][0] + w_next = main_batch[batch_next_key][1] + + for b in range(self.branches): + loss, energy_l, mass_l, match_l, recons_l = self._compute_transition_loss( + x_curr, w_curr, x_next, w_next, b, i, + branch_sample_pairs[b] if self.args.manifold else None + ) + total_loss += loss + total_energy_loss += energy_l + total_mass_loss += mass_l + total_match_loss += match_l + total_recons_loss += recons_l + num_transitions += 1 + + # Average losses + avg_energy_loss = total_energy_loss / num_transitions if num_transitions > 0 else total_energy_loss + avg_mass_loss = total_mass_loss / num_transitions if num_transitions > 0 else total_mass_loss + avg_match_loss = total_match_loss / num_transitions if num_transitions > 0 else total_match_loss + avg_recons_loss = total_recons_loss / num_transitions if num_transitions > 0 else total_recons_loss + + # Log individual components + if self.joint: + if validation: + self.log("JointTrain/val_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/val_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/val_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/val_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True) + else: + self.log("JointTrain/train_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/train_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/train_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("JointTrain/train_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True) + else: + if validation: + self.log("GrowthNet/val_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/val_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/val_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/val_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True) + else: + self.log("GrowthNet/train_energy_loss", avg_energy_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/train_mass_loss", avg_mass_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/train_match_loss", avg_match_loss, on_step=False, on_epoch=True, prog_bar=True) + self.log("GrowthNet/train_recons_loss", avg_recons_loss, on_step=False, on_epoch=True, prog_bar=True) + + return total_loss + + def _compute_transition_loss(self, x0, w0, x1, w1, branch_idx, transition_idx, metric_pair): + """Compute loss for a single timepoint transition.""" + if self.ot_sampler is not None: + x0, x1 = self.ot_sampler.sample_plan(x0, x1, replace=True) + + # Simulate trajectory using flow network + t_span = torch.linspace(0, 1, 10, device=x0.device) + + flow_model = flow_model_torch_wrapper(self.flow_nets[branch_idx]) + node = NeuralODE(flow_model, solver="euler", sensitivity="adjoint") + + with torch.no_grad(): + traj = node.trajectory(x0, t_span) + + # Compute energy and mass losses + energy_loss = 0 + mass_loss = 0 + neg_weight_penalty = 0 + + for t_idx in range(len(t_span)): + t = t_span[t_idx] + xt = traj[t_idx] + + # Growth rate + growth = self.growth_nets[branch_idx](t.unsqueeze(0).expand(xt.shape[0]), xt) + + # Energy loss + if self.args.manifold and metric_pair is not None: + start_samples, end_samples = metric_pair + samples = torch.cat([start_samples, end_samples], dim=0) + _, kinetic, potential = self.data_manifold_metric.calculate_velocity( + xt, torch.zeros_like(xt), samples, transition_idx + ) + energy = kinetic + potential + else: + energy = (growth ** 2).sum(dim=-1) + + energy_loss += energy.mean() + + # Mass conservation + growth_sum = growth.sum(dim=-1, keepdim=True) # Keep dimension for proper broadcasting + wt = w0 * torch.exp(growth_sum) + mass = wt.sum() + mass_loss += (mass - w1.sum()).abs() + neg_weight_penalty += torch.relu(-wt).sum() + + # Match and reconstruction losses (computed at final time) + xt_final = traj[-1] + match_loss = mean_squared_error(wt, w1) + recons_loss = self.recons_loss(xt_final, x1) + + total_loss = ( + self.lambda_energy * energy_loss + + self.lambda_mass * (mass_loss + neg_weight_penalty) + + self.lambda_match * match_loss + + self.lambda_recons * recons_loss + ) + + return total_loss, energy_loss, mass_loss + neg_weight_penalty, match_loss, recons_loss + + def training_step(self, batch, batch_idx): + if isinstance(batch, (list, tuple)): + batch = batch[0] + main_batch = batch["train_samples"] + metric_batch = batch["metric_samples"] + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + if isinstance(metric_batch, tuple): + metric_batch = metric_batch[0] + + loss = self._compute_loss(main_batch, metric_batch) + + if self.joint: + self.log( + "JointTrain/train_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + else: + self.log( + "GrowthNet/train_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return loss + + def validation_step(self, batch, batch_idx): + if isinstance(batch, (list, tuple)): + batch = batch[0] + main_batch = batch["val_samples"] + metric_batch = batch["metric_samples"] + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + if isinstance(metric_batch, tuple): + metric_batch = metric_batch[0] + + loss = self._compute_loss(main_batch, metric_batch, validation=True) + + if self.joint: + self.log( + "JointTrain/val_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + else: + self.log( + "GrowthNet/val_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return loss + + def configure_optimizers(self): + import itertools + params = list(itertools.chain(*[net.parameters() for net in self.growth_nets])) + if self.joint: + params += list(itertools.chain(*[net.parameters() for net in self.flow_nets])) + + if self.optimizer_name == "adam": + optimizer = torch.optim.Adam(params, lr=self.lr) + elif self.optimizer_name == "adamw": + optimizer = torch.optim.AdamW( + params, + lr=self.lr, + weight_decay=self.weight_decay, + ) + return optimizer \ No newline at end of file diff --git a/src/branch_interpolant_train.py b/src/branch_interpolant_train.py new file mode 100755 index 0000000000000000000000000000000000000000..77a921d66b0d657bc0ab4b0570f00081f58a6387 --- /dev/null +++ b/src/branch_interpolant_train.py @@ -0,0 +1,477 @@ +import sys +import os +import torch +import pytorch_lightning as pl +from .ema import EMA +import itertools +from .utils import plot_lidar +import matplotlib.pyplot as plt + +class BranchInterpolantTrain(pl.LightningModule): + def __init__( + self, + flow_matcher, + args, + skipped_time_points: list = None, + ot_sampler=None, + + state_cost=None, + data_manifold_metric=None, + ): + super().__init__() + self.save_hyperparameters() + self.args = args + + self.flow_matcher = flow_matcher + + # list of geopath nets + self.geopath_nets = flow_matcher.geopath_nets + self.branches = len(self.geopath_nets) + self.metric_clusters = args.metric_clusters + + self.ot_sampler = ot_sampler + self.skipped_time_points = skipped_time_points if skipped_time_points else [] + self.optimizer_name = args.geopath_optimizer + self.lr = args.geopath_lr + self.weight_decay = args.geopath_weight_decay + self.args = args + self.multiply_validation = 4 + + self.first_loss = None + self.timesteps = None + self.computing_reference_loss = False + + # updates + self.state_cost = state_cost + self.data_manifold_metric = data_manifold_metric + self.whiten = args.whiten + + def forward(self, x0, x1, t, branch_idx): + # return specific branch interpolant + return self.geopath_nets[branch_idx](x0, x1, t) + + def on_train_start(self): + self.first_loss = self.compute_initial_loss() + print("first loss") + print(self.first_loss) + + # to edit + def compute_initial_loss(self): + # Set all GeoPath networks to eval mode + for net in self.geopath_nets: + net.train(mode=False) + + total_loss = 0 + total_count = 0 + with torch.enable_grad(): + self.t_val = [] + for i in range( + self.trainer.datamodule.num_timesteps - len(self.skipped_time_points) + ): + self.t_val.append( + torch.rand( + self.trainer.datamodule.batch_size * self.multiply_validation, + requires_grad=True, + ) + ) + self.computing_reference_loss = True + with torch.no_grad(): + old_alpha = self.flow_matcher.alpha + self.flow_matcher.alpha = 0 + for batch in self.trainer.datamodule.train_dataloader(): + + if isinstance(batch, (list, tuple)): + batch = batch[0] + if isinstance(batch, dict) and "train_samples" in batch: + main_batch_init = batch["train_samples"] + metric_batch_init = batch["metric_samples"] + if isinstance(main_batch_init, tuple): + main_batch_init = main_batch_init[0] + if isinstance(metric_batch_init, tuple): + metric_batch_init = metric_batch_init[0] + else: + main_batch_init = batch + metric_batch_init = [] + + self.timesteps = torch.linspace( + 0.0, 1.0, len(main_batch_init["x0"]) + ).tolist() + + loss = self._compute_loss( + main_batch_init, + metric_batch_init, + ) + print("initial loss") + print(loss) + # Skip NaN/Inf batches to prevent poisoning the average + if not (torch.isnan(loss) or torch.isinf(loss)): + total_loss += loss.item() + total_count += 1 + self.flow_matcher.alpha = old_alpha + + self.computing_reference_loss = False + + # Set all GeoPath networks back to training mode + for net in self.geopath_nets: + net.train(mode=True) + return total_loss / total_count if total_count > 0 else 1.0 + + def _compute_loss(self, main_batch, metric_samples_batch=None): + + x0s = [main_batch["x0"][0]] + w0s = [main_batch["x0"][1]] + + x1s_list = [] + w1s_list = [] + + if self.branches > 1: + for i in range(self.branches): + x1s_list.append([main_batch[f"x1_{i+1}"][0]]) + w1s_list.append([main_batch[f"x1_{i+1}"][1]]) + else: + x1s_list.append([main_batch["x1"][0]]) + w1s_list.append([main_batch["x1"][1]]) + + if self.args.manifold: + #changed + if self.metric_clusters == 7: + # For 6 branches with 7 clusters (1 root + 6 branch endpoints) + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 + (metric_samples_batch[0], metric_samples_batch[3]), # x0 → x1_3 + (metric_samples_batch[0], metric_samples_batch[4]), # x0 → x1_4 + (metric_samples_batch[0], metric_samples_batch[5]), # x0 → x1_5 + (metric_samples_batch[0], metric_samples_batch[6]), # x0 → x1_6 + ] + elif self.metric_clusters == 4: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + (metric_samples_batch[0], metric_samples_batch[3]), + ] + elif self.metric_clusters == 3: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2) + ] + elif self.metric_clusters == 2 and self.branches == 2: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[0]), # x0 → x1_1 (branch 1) + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2) + ] + elif self.metric_clusters == 2: + # For any number of branches with 2 metric clusters (initial vs remaining) + # All branches use the same metric cluster pair + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]) # x0 → all branches + ] * self.branches + else: + branch_sample_pairs = [ + (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1) + ] + """samples0, samples1, samples2 = ( + metric_samples_batch[0], + metric_samples_batch[1], + metric_samples_batch[2] + )""" + + assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches" + + # compute sum of velocities for each branch + loss = 0 + velocities = [] + for branch_idx in range(self.branches): + + ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx) + + for i in range(len(ts)): + # calculate kinetic and potential energy of the predicted interpolant + if self.args.manifold: + start_samples, end_samples = branch_sample_pairs[branch_idx] + + samples = torch.cat([start_samples, end_samples], dim=0) + #print("metric sample shape") + #print(samples.shape) + vel, _, _ = self.data_manifold_metric.calculate_velocity( + xts[i], uts[i], samples, i + ) + else: + vel = torch.sqrt((uts[i]**2).sum(dim =-1) + self.state_cost(xts[i])) + #vel = (uts[i]**2).sum(dim =-1) + + velocities.append(vel) + + velocity_loss = torch.mean(torch.cat(velocities) ** 2) + + self.log( + "BranchPathNet/mean_velocity_geopath", + velocity_loss, + on_step=False, + on_epoch=True, + prog_bar=True, + ) + + return velocity_loss + + def _process_flow(self, x0s, x1s, branch_idx): + ts, xts, uts = [], [], [] + t_start = self.timesteps[0] + i_start = 0 + + for i, (x0, x1) in enumerate(zip(x0s, x1s)): + x0, x1 = torch.squeeze(x0), torch.squeeze(x1) + if self.trainer.validating or self.computing_reference_loss: + repeat_tuple = (self.multiply_validation, 1) + (1,) * ( + len(x0.shape) - 2 + ) + x0 = x0.repeat(repeat_tuple) + x1 = x1.repeat(repeat_tuple) + + if self.ot_sampler is not None: + x0, x1 = self.ot_sampler.sample_plan( + x0, + x1, + replace=True, + ) + if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]: + t_start_next = self.timesteps[i + 2] + else: + t_start_next = self.timesteps[i + 1] + + t = None + if self.trainer.validating or self.computing_reference_loss: + t = self.t_val[i] + + t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow( + x0, x1, t_start, t_start_next, branch_idx, training_geopath_net=True, t=t + ) + ts.append(t) + xts.append(xt) + uts.append(ut) + t_start = t_start_next + + return ts, xts, uts + + def training_step(self, batch, batch_idx): + # Handle both dict and tuple batch formats from CombinedLoader + if isinstance(batch, (list, tuple)): + batch = batch[0] + if isinstance(batch, dict) and "train_samples" in batch: + main_batch = batch["train_samples"] + metric_batch = batch["metric_samples"] + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + if isinstance(metric_batch, tuple): + metric_batch = metric_batch[0] + else: + # Fallback + main_batch = batch.get("train_samples", batch) + metric_batch = batch.get("metric_samples", []) + + # Debug: print structure + if batch_idx == 0: + print(f"DEBUG batch type: {type(batch)}") + if isinstance(batch, dict): + print(f"DEBUG batch keys: {batch.keys()}") + print(f"DEBUG train_samples type: {type(batch.get('train_samples'))}") + if isinstance(batch.get("train_samples"), dict): + print(f"DEBUG train_samples keys: {batch['train_samples'].keys()}") + print(f"DEBUG x0 type: {type(batch['train_samples'].get('x0'))}") + if 'x0' in batch['train_samples']: + x0_item = batch['train_samples']['x0'] + print(f"DEBUG x0 structure: {type(x0_item)}") + if isinstance(x0_item, (list, tuple)): + print(f"DEBUG x0 length: {len(x0_item)}") + if len(x0_item) > 0: + print(f"DEBUG x0[0] shape: {x0_item[0].shape if hasattr(x0_item[0], 'shape') else 'no shape'}") + print(f"DEBUG main_batch type: {type(main_batch)}") + if isinstance(main_batch, dict): + print(f"DEBUG main_batch keys: {main_batch.keys()}") + + self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() + tangential_velocity_loss = self._compute_loss(main_batch, metric_batch) + + if self.first_loss: + tangential_velocity_loss = tangential_velocity_loss / self.first_loss + + self.log( + "BranchPathNet/mean_geopath_geopath", + (self.flow_matcher.geopath_net_output.abs().mean()), + on_step=False, + on_epoch=True, + prog_bar=True, + ) + + self.log( + "BranchPathNet/train_loss_geopath", + tangential_velocity_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + return tangential_velocity_loss + + def validation_step(self, batch, batch_idx): + # Handle both dict and tuple batch formats from CombinedLoader + if isinstance(batch, (list, tuple)): + batch = batch[0] + if isinstance(batch, dict) and "val_samples" in batch: + main_batch = batch["val_samples"] + metric_batch = batch["metric_samples"] + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + if isinstance(metric_batch, tuple): + metric_batch = metric_batch[0] + else: + # Fallback + main_batch = batch.get("val_samples", batch) + metric_batch = batch.get("metric_samples", []) + + self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist() + tangential_velocity_loss = self._compute_loss(main_batch, metric_batch) + if self.first_loss: + tangential_velocity_loss = tangential_velocity_loss / self.first_loss + + self.log( + "BranchPathNet/val_loss_geopath", + tangential_velocity_loss, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return tangential_velocity_loss + + + def test_step(self, batch, batch_idx): + # Handle both tuple and dict batch formats from CombinedLoader + if isinstance(batch, dict): + main_batch = batch["test_samples"] + metric_batch = batch["metric_samples"] + # CombinedLoader may wrap values in a tuple + if isinstance(main_batch, tuple): + main_batch = main_batch[0] + if isinstance(metric_batch, tuple): + metric_batch = metric_batch[0] + else: + # batch is a tuple: (test_samples, metric_samples) + main_batch = batch[0][0] + metric_batch = batch[1][0] + + x0 = main_batch["x0"][0] # [B, D] + cloud_points = main_batch["dataset"][0] # full dataset, [N, D] + + x0 = x0.to(self.device) + cloud_points = cloud_points.to(self.device) + + t_vals = [0.25, 0.5, 0.75] + t_labels = ["t=1/4", "t=1/2", "t=3/4"] + + colors = { + "x0": "#4D176C", + "t=1/4": "#5C3B9D", + "t=1/2": "#6172B9", + "t=3/4": "#AC4E51", + "x1": "#771F4F", + } + + # Unwhiten cloud points if needed + if self.whiten: + cloud_points = torch.tensor( + self.trainer.datamodule.scaler.inverse_transform(cloud_points.cpu().numpy()) + ) + + for i in range(self.branches): + geopath = self.geopath_nets[i] + x1_key = f"x1_{i + 1}" + if x1_key not in main_batch: + print(f"Skipping branch {i + 1}: no final distribution {x1_key}") + continue + + x1 = main_batch[x1_key][0].to(self.device) + print(x1.shape) + print(x0.shape) + interpolated_points = [] + with torch.no_grad(): + for t_scalar in t_vals: + t_tensor = torch.full((x0.shape[0], 1), t_scalar, device=self.device) # [B, 1] + xt = geopath(x0, x1, t_tensor).cpu() # [B, D] + if self.whiten: + xt = torch.tensor( + self.trainer.datamodule.scaler.inverse_transform(xt.numpy()) + ) + interpolated_points.append(xt) + + if self.whiten: + x0_plot = torch.tensor( + self.trainer.datamodule.scaler.inverse_transform(x0.cpu().numpy()) + ) + x1_plot = torch.tensor( + self.trainer.datamodule.scaler.inverse_transform(x1.cpu().numpy()) + ) + else: + x0_plot = x0.cpu() + x1_plot = x1.cpu() + + # Plot + fig = plt.figure(figsize=(6, 5)) + ax = fig.add_subplot(111, projection="3d", computed_zorder=False) + ax.view_init(elev=30, azim=-115, roll=0) + plot_lidar(ax, cloud_points) + + # Initial x₀ + ax.scatter( + x0_plot[:, 0], x0_plot[:, 1], x0_plot[:, 2], + s=15, alpha=1.0, color=colors["x0"], label="x₀", depthshade=True, + edgecolors="white", + linewidths=0.3 + ) + + # Interpolated points + for xt, t_label in zip(interpolated_points, t_labels): + ax.scatter( + xt[:, 0], xt[:, 1], xt[:, 2], + s=15, alpha=1.0, color=colors[t_label], label=t_label, depthshade=True, + edgecolors="white", + linewidths=0.3 + ) + + # Final x₁ + ax.scatter( + x1_plot[:, 0], x1_plot[:, 1], x1_plot[:, 2], + s=15, alpha=1.0, color=colors["x1"], label="x₁", depthshade=True, + edgecolors="white", + linewidths=0.3 + ) + + ax.legend() + + # Use consistent path structure for results + run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name + results_dir = os.path.join(self.args.working_dir, 'results', run_name) + figures_dir = os.path.join(results_dir, 'figures') + os.makedirs(figures_dir, exist_ok=True) + + save_path = f"{figures_dir}/lidar_geopath_branch_{i+1}.png" + plt.savefig(save_path, dpi=300) + plt.close() + + def optimizer_step(self, *args, **kwargs): + super().optimizer_step(*args, **kwargs) + for net in self.geopath_nets: + if isinstance(net, EMA): + net.update_ema() + + def configure_optimizers(self): + if self.optimizer_name == "adam": + optimizer = torch.optim.Adam( + itertools.chain(*[net.parameters() for net in self.geopath_nets]), lr=self.lr + ) + elif self.optimizer_name == "adamw": + optimizer = torch.optim.AdamW( + itertools.chain(*[net.parameters() for net in self.geopath_nets]), lr=self.lr + ) + return optimizer diff --git a/src/branchsbm.py b/src/branchsbm.py new file mode 100755 index 0000000000000000000000000000000000000000..734fb43a1f542137b41ee428ede6f295b2dfaf3a --- /dev/null +++ b/src/branchsbm.py @@ -0,0 +1,108 @@ +import sys +import torch +from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, pad_t_like_x +import torch.nn as nn + +class BranchSBM(ConditionalFlowMatcher): + def __init__( + self, geopath_nets: nn.ModuleList = None, alpha: float = 1.0, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.alpha = alpha + self.geopath_nets = geopath_nets + if self.alpha != 0: + assert ( + geopath_nets is not None + ), "GeoPath model must be provided if alpha != 0" + + self.branches = len(geopath_nets) + + def gamma(self, t, t_min, t_max): + return ( + 1.0 + - ((t - t_min) / (t_max - t_min)) ** 2 + - ((t_max - t) / (t_max - t_min)) ** 2 + ) + + def d_gamma(self, t, t_min, t_max): + return 2 * (-2 * t + t_max + t_min) / (t_max - t_min) ** 2 + + def compute_mu_t(self, x0, x1, t, t_min, t_max, branch_idx): + assert branch_idx < self.branches, "Index out of bounds" + + with torch.enable_grad(): + t = pad_t_like_x(t, x0) + if self.alpha == 0: + return (t_max - t) / (t_max - t_min) * x0 + (t - t_min) / ( + t_max - t_min + ) * x1 + + # compute value for specific branch + self.geopath_net_output = self.geopath_nets[branch_idx](x0, x1, t) + if self.geopath_nets[branch_idx].time_geopath: + self.doutput_dt = torch.autograd.grad( + self.geopath_net_output, + t, + grad_outputs=torch.ones_like(self.geopath_net_output), + create_graph=False, + retain_graph=True, + )[0] + return ( + (t_max - t) / (t_max - t_min) * x0 + + (t - t_min) / (t_max - t_min) * x1 + + self.gamma(t, t_min, t_max) * self.geopath_net_output + ) + + def sample_xt(self, x0, x1, t, epsilon, t_min, t_max, branch_idx): + assert branch_idx < self.branches, "Index out of bounds" + mu_t = self.compute_mu_t(x0, x1, t, t_min, t_max, branch_idx) + sigma_t = self.compute_sigma_t(t) + sigma_t = pad_t_like_x(sigma_t, x0) + return mu_t + sigma_t * epsilon + + def sample_location_and_conditional_flow( + self, + x0, + x1, + t_min, + t_max, + branch_idx, + training_geopath_net=False, + midpoint_only=False, + t=None, + ): + + self.training_geopath_net = training_geopath_net + with torch.enable_grad(): + if t is None: + t = torch.rand(x0.shape[0], requires_grad=True) + t = t.type_as(x0) + t = t * (t_max - t_min) + t_min + if midpoint_only: + t = (t_max + t_min) / 2 * torch.ones_like(t).type_as(x0) + + assert len(t) == x0.shape[0], "t has to have batch size dimension" + + eps = self.sample_noise_like(x0) + + # compute xt and ut for branch_idx + xt = self.sample_xt(x0, x1, t, eps, t_min, t_max, branch_idx) + ut = self.compute_conditional_flow(x0, x1, t, xt, t_min, t_max, branch_idx) + + return t, xt, ut + + def compute_conditional_flow(self, x0, x1, t, xt, t_min, t_max, branch_idx): + del xt + t = pad_t_like_x(t, x0) + if self.alpha == 0: + return (x1 - x0) / (t_max - t_min) + + return ( + (x1 - x0) / (t_max - t_min) + + self.d_gamma(t, t_min, t_max) * self.geopath_net_output + + ( + self.gamma(t, t_min, t_max) * self.doutput_dt + if self.geopath_nets[branch_idx].time_geopath + else 0 + ) + ) \ No newline at end of file diff --git a/src/ema.py b/src/ema.py new file mode 100755 index 0000000000000000000000000000000000000000..395e515688c0102d4e3a48aec1eb571eb8357d3f --- /dev/null +++ b/src/ema.py @@ -0,0 +1,64 @@ +import torch + +class EMA(torch.nn.Module): + def __init__(self, model: torch.nn.Module, decay: float = 0.999): + super().__init__() + self.model = model + self.decay = decay + if hasattr(self.model, "time_geopath"): + self.time_geopath = self.model.time_geopath + + # Put this in a buffer so that it gets included in the state dict + self.register_buffer("num_updates", torch.tensor(0)) + + self.shadow_params = torch.nn.ParameterList( + [ + torch.nn.Parameter(p.clone().detach(), requires_grad=False) + for p in model.parameters() + if p.requires_grad + ] + ) + self.backup_params = [] + + def train(self, mode: bool): + if self.training and mode == False: + # Switching from train mode to eval mode. Backup the model parameters and + # overwrite with shadow params + self.backup() + self.copy_to_model() + elif not self.training and mode == True: + # Switching from eval to train mode. Restore the `backup_params` + self.restore_to_model() + + super().train(mode) + + def update_ema(self): + self.num_updates += 1 + num_updates = self.num_updates.item() + decay = min(self.decay, (1 + num_updates) / (10 + num_updates)) + with torch.no_grad(): + params = [p for p in self.model.parameters() if p.requires_grad] + for shadow, param in zip(self.shadow_params, params): + shadow.sub_((1 - decay) * (shadow - param)) + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def copy_to_model(self): + # copy the shadow (ema) parameters to the model + params = [p for p in self.model.parameters() if p.requires_grad] + for shaddow, param in zip(self.shadow_params, params): + param.data.copy_(shaddow.data) + + def backup(self): + # Backup the current model parameters + if len(self.backup_params) > 0: + for p, b in zip(self.model.parameters(), self.backup_params): + b.data.copy_(p.data) + else: + self.backup_params = [param.clone() for param in self.model.parameters()] + + def restore_to_model(self): + # Restores the backed up parameters to the model. + for param, backup in zip(self.model.parameters(), self.backup_params): + param.data.copy_(backup.data) \ No newline at end of file diff --git a/src/geo_metrics/.DS_Store b/src/geo_metrics/.DS_Store new file mode 100755 index 0000000000000000000000000000000000000000..2f8b801b33ce3a0c85c5b39286bfece5d2dd26f9 Binary files /dev/null and b/src/geo_metrics/.DS_Store differ diff --git a/src/geo_metrics/land.py b/src/geo_metrics/land.py new file mode 100755 index 0000000000000000000000000000000000000000..6b01b099b688cfe6dc872b80ce408aae18c92ee4 --- /dev/null +++ b/src/geo_metrics/land.py @@ -0,0 +1,27 @@ +### Adapted from Metric Flow Matching (https://github.com/kkapusniak/metric-flow-matching) + +import torch + +def weighting_function(x, samples, gamma): + pairwise_sq_diff = (x[:, None, :] - samples[None, :, :]) ** 2 + pairwise_sq_dist = pairwise_sq_diff.sum(-1) + weights = torch.exp(-pairwise_sq_dist / (2 * gamma**2)) + return weights + + +def land_metric_tensor(x, samples, gamma, rho): + weights = weighting_function(x, samples, gamma) # Shape [B, N] + differences = samples[None, :, :] - x[:, None, :] # Shape [B, N, D] + squared_differences = differences**2 # Shape [B, N, D] + + # Compute the sum of weighted squared differences for each dimension + M_dd_diag = torch.einsum("bn,bnd->bd", weights, squared_differences) + rho + + # Invert the metric tensor diagonal for each x_t + M_dd_inv_diag = 1.0 / M_dd_diag # Shape [B, D] since it's diagonal + return M_dd_inv_diag + + +def weighting_function_dt(x, dx_dt, samples, gamma, weights): + pairwise_sq_diff_dt = (x[:, None, :] - samples[None, :, :]) * dx_dt[:, None, :] + return -pairwise_sq_diff_dt.sum(-1) * weights / (gamma**2) diff --git a/src/geo_metrics/metric_factory.py b/src/geo_metrics/metric_factory.py new file mode 100755 index 0000000000000000000000000000000000000000..1ad4c086a35484d6bc6e124301a19456d11113a0 --- /dev/null +++ b/src/geo_metrics/metric_factory.py @@ -0,0 +1,102 @@ +### Adapted from Metric Flow Matching (https://github.com/kkapusniak/metric-flow-matching) + +import sys +import torch +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from torch.utils.data import Dataset, DataLoader + +from .land import land_metric_tensor +from .rbf import RBFNetwork + +class DataManifoldMetric: + def __init__( + self, + args, + skipped_time_points=None, + datamodule=None, + ): + self.skipped_time_points = skipped_time_points + self.datamodule = datamodule + + self.gamma = args.gamma_current + self.rho = args.rho + self.metric = args.velocity_metric + self.n_centers = args.n_centers + self.kappa = args.kappa + self.metric_epochs = args.metric_epochs + self.metric_patience = args.metric_patience + self.lr = args.metric_lr + self.alpha_metric = args.alpha_metric + self.image_data = args.data_type == "image" + self.accelerator = args.accelerator + + self.called_first_time = True + self.args = args + + def calculate_metric(self, x_t, samples, current_timestep): + if self.metric == "land": + M_dd_x_t = ( + land_metric_tensor(x_t, samples, self.gamma, self.rho) + ** self.alpha_metric + ) + + elif self.metric == "rbf": + if self.called_first_time: + # Train a single RBF network for all timesteps + print("Learning single RBF network for all timesteps") + self.rbf_network = RBFNetwork( + current_timestep=0, + next_timestep=self.datamodule.num_timesteps - 1, + n_centers=self.n_centers, + kappa=self.kappa, + lr=self.lr, + datamodule=self.datamodule, + args=self.args + ) + early_stop_callback = pl.callbacks.EarlyStopping( + monitor="MetricModel/train_loss_learn_metric_epoch", + patience=self.metric_patience, + mode="min", + ) + trainer = pl.Trainer( + max_epochs=self.metric_epochs, + accelerator=self.accelerator, + logger=WandbLogger(), + num_sanity_val_steps=0, + callbacks=( + [early_stop_callback] if not self.image_data else None + ), + ) + if self.image_data: + self.dataloader = DataLoader( + self.datamodule.all_data, + batch_size=128, + shuffle=True, + ) + trainer.fit(self.rbf_network, self.dataloader) + else: + trainer.fit(self.rbf_network, self.datamodule) + self.called_first_time = False + print("Learning RBF network... Done") + M_dd_x_t = self.rbf_network.compute_metric( + x_t, + epsilon=self.rho, + alpha=self.alpha_metric, + image_hx=self.image_data, + ) + return M_dd_x_t + + def calculate_velocity(self, x_t, u_t, samples, timestep): + + if len(u_t.shape) > 2: + u_t = u_t.reshape(u_t.shape[0], -1) + x_t = x_t.reshape(x_t.shape[0], -1) + M_dd_x_t = self.calculate_metric(x_t, samples, timestep).to(u_t.device) + + # Clamp to prevent NaN from sqrt of negative values when the RBF + # metric tensor is not positive-definite for some inputs + velocity = torch.sqrt(torch.clamp(((u_t**2) * M_dd_x_t).sum(dim=-1), min=0)) + ut_sum = (u_t**2).sum(dim=-1) + metric_sum = M_dd_x_t.sum(dim=-1) + return velocity, ut_sum, metric_sum diff --git a/src/geo_metrics/rbf.py b/src/geo_metrics/rbf.py new file mode 100755 index 0000000000000000000000000000000000000000..fe17a9942ad6e8ce19eda286c54cf58d2e1bf1d5 --- /dev/null +++ b/src/geo_metrics/rbf.py @@ -0,0 +1,161 @@ +### Adapted from Metric Flow Matching (https://github.com/kkapusniak/metric-flow-matching) + +import pytorch_lightning as pl +import torch +from sklearn.cluster import KMeans +import numpy as np + +class RBFNetwork(pl.LightningModule): + def __init__( + self, + current_timestep, + next_timestep, + n_centers: int = 100, + kappa: float = 1.0, + lr=1e-2, + datamodule=None, + image_data=False, + args=None + ): + super().__init__() + self.K = n_centers + self.current_timestep = current_timestep + self.next_timestep = next_timestep + self.clustering_model = KMeans(n_clusters=self.K) + self.kappa = kappa + self.last_val_loss = 1 + self.lr = lr + self.W = torch.nn.Parameter(torch.rand(self.K, 1)) + self.datamodule = datamodule + self.image_data = image_data + self.args = args + + def on_before_zero_grad(self, *args, **kwargs): + self.W.data = torch.clamp(self.W.data, min=0.0001) + + def on_train_start(self): + with torch.no_grad(): + + batch = next(iter(self.trainer.datamodule.train_dataloader())) + + metric_samples = batch[0]["metric_samples"][0] + all_data = torch.cat(metric_samples) + data_to_fit = all_data + + print("Fitting Clustering model...") + self.clustering_model.fit(data_to_fit) + + clusters = ( + self.calculate_centroids(all_data, self.clustering_model.labels_) + if self.image_data + else self.clustering_model.cluster_centers_ + ) + + self.C = torch.tensor(clusters, dtype=torch.float32).to(self.device) + labels = self.clustering_model.labels_ + sigmas = np.zeros((self.K, 1)) + + for k in range(self.K): + points = all_data[labels == k, :] + variance = ((points - clusters[k]) ** 2).mean(axis=0) + sigmas[k, :] = np.sqrt( + variance.sum() if self.image_data else variance.mean() + ) + + # Add small epsilon to prevent division by zero + sigmas = np.maximum(sigmas, 1e-6) + + self.lamda = torch.tensor( + 0.5 / (self.kappa * sigmas) ** 2, dtype=torch.float32 + ).to(self.device) + + def forward(self, x): + if len(x.shape) > 2: + x = x.reshape(x.shape[0], -1).to(self.C.device) + + x = x.to(self.C.device) + dist2 = torch.cdist(x, self.C) ** 2 + self.phi_x = torch.exp(-0.5 * self.lamda[None, :, :] * dist2[:, :, None]) + + h_x = (self.W.to(x.device) * self.phi_x).sum(dim=1) + + return h_x + + def training_step(self, batch, batch_idx): + if self.args.data_type == "scrna" or self.args.data_type == "tahoe": + main_batch = batch[0]["train_samples"][0] + else: + main_batch = batch["train_samples"][0] + + x0 = main_batch["x0"][0] + if self.args.branches == 1: + x1 = main_batch["x1"][0] + inputs = torch.cat([x0, x1], dim=0).to(self.device) + else: + x1_1 = main_batch["x1_1"][0] + x1_2 = main_batch["x1_2"][0] + + inputs = torch.cat([x0, x1_1, x1_2], dim=0).to(self.device) + print("inputs shape") + print(inputs.shape) + + loss = ((1 - self.forward(inputs)) ** 2).mean() + self.log( + "MetricModel/train_loss_learn_metric", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + return loss + + def validation_step(self, batch, batch_idx): + if self.args.data_type == "scrna" or self.args.data_type == "tahoe": + main_batch = batch[0]["val_samples"][0] + else: + main_batch = batch["val_samples"][0] + + x0 = main_batch["x0"][0] + if self.args.branches == 1: + x1 = main_batch["x1"][0] + inputs = torch.cat([x0, x1], dim=0).to(self.device) + else: + x1_1 = main_batch["x1_1"][0] + x1_2 = main_batch["x1_2"][0] + + inputs = torch.cat([x0, x1_1, x1_2], dim=0).to(self.device) + + h = self.forward(inputs) + + loss = ((1 - h) ** 2).mean() + self.log( + "MetricModel/val_loss_learn_metric", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.last_val_loss = loss.detach() + return loss + + def calculate_centroids(self, all_data, labels): + unique_labels = np.unique(labels) + centroids = np.zeros((len(unique_labels), all_data.shape[1])) + for i, label in enumerate(unique_labels): + centroids[i] = all_data[labels == label].mean(axis=0) + return centroids + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) + return optimizer + + def compute_metric(self, x, alpha=1, epsilon=1e-2, image_hx=False): + if epsilon < 0: + epsilon = (1 - float(self.last_val_loss)) / abs(epsilon) + h_x = self.forward(x) + if image_hx: + h_x = 1 - torch.abs(1 - h_x) + M_x = 1 / (h_x**alpha + epsilon) + else: + M_x = 1 / (h_x + epsilon) ** alpha + return M_x \ No newline at end of file diff --git a/src/losses/.DS_Store b/src/losses/.DS_Store new file mode 100755 index 0000000000000000000000000000000000000000..52f833718427f84cd2672d8a98189c6c96b7051c Binary files /dev/null and b/src/losses/.DS_Store differ diff --git a/src/losses/energy_loss.py b/src/losses/energy_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..dac45da0bb526d52d460bafeaeefe2e10c0ac68a --- /dev/null +++ b/src/losses/energy_loss.py @@ -0,0 +1,74 @@ +import os, math, numpy as np +import torch +import torch.nn as nn +from torchdiffeq import odeint as odeint2 +from torchmetrics.functional import mean_squared_error +import ot + +class EnergySolver(nn.Module): + def __init__(self, flow_net, growth_net, state_cost, data_manifold_metric=None, samples=None, timestep_idx=0): + super(EnergySolver, self).__init__() + self.flow_net = flow_net + self.growth_net = growth_net + self.state_cost = state_cost + + self.data_manifold_metric = data_manifold_metric + self.samples = samples + self.timestep_idx = timestep_idx + + def forward(self, t, state): + xt, wt, mt = state + + xt.requires_grad_(True) + wt.requires_grad_(True) + mt.requires_grad_(True) + + t.requires_grad_(True) + + ut = self.flow_net(t, xt) + gt = self.growth_net(t, xt) + + time=t.expand(xt.shape[0], 1) + time.requires_grad_(True) + + dx_dt = ut + dw_dt = gt + + if self.data_manifold_metric is not None: + vel, _, _ = self.data_manifold_metric.calculate_velocity( + xt, ut, self.samples, self.timestep_idx + ) + + dm_dt = ((vel ** 2).sum(dim =-1) + (gt ** 2)) * wt + else: + dm_dt = ((ut**2).sum(dim =-1) + self.state_cost(xt) + (0.1 * (gt ** 2))) * wt + + assert xt.shape == dx_dt.shape, f"dx mismatch: expected {xt.shape}, got {dx_dt.shape}" + assert wt.shape == dw_dt.shape, f"dw mismatch: expected {wt.shape}, got {dw_dt.shape}" + assert mt.shape == dm_dt.shape, f"dm mismatch: expected {mt.shape}, got {dm_dt.shape}" + return dx_dt, dw_dt, dm_dt + +class ReconsLoss(nn.Module): + def __init__(self, hinge_value=0.01): + super(ReconsLoss, self).__init__() + self.hinge_value = hinge_value + + def __call__(self, source, target, groups = None, to_ignore = None, top_k = 5): + if groups is not None: + # for global loss + c_dist = torch.stack([ + torch.cdist(source[i], target[i]) + + for i in range(1,len(groups)) + if groups[i] != to_ignore + ]) + else: + # for local loss + c_dist = torch.stack([ + torch.cdist(source, target) + ]) + values, _ = torch.topk(c_dist, top_k, dim=2, largest=False, sorted=False) + values -= self.hinge_value + values[values<0] = 0 + loss = torch.mean(values) + return loss \ No newline at end of file diff --git a/src/networks/.DS_Store b/src/networks/.DS_Store new file mode 100755 index 0000000000000000000000000000000000000000..61889c0d1e5593fd406a182718266ba3731301e9 Binary files /dev/null and b/src/networks/.DS_Store differ diff --git a/src/networks/flow_mlp.py b/src/networks/flow_mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..02f0a944882ef18bd673b63a1526349db69b470b --- /dev/null +++ b/src/networks/flow_mlp.py @@ -0,0 +1,17 @@ +import sys +import torch +from .mlp_base import SimpleDenseNet + + +class VelocityNet(SimpleDenseNet): + def __init__(self, dim: int, *args, **kwargs): + super().__init__(input_size=dim + 1, target_size=dim, *args, **kwargs) + + def forward(self, t, x): + + if t.dim() < 1 or t.shape[0] != x.shape[0]: + t = t.repeat(x.shape[0])[:, None] + if t.dim() < 2: + t = t[:, None] + x = torch.cat([t, x], dim=-1) + return self.model(x) diff --git a/src/networks/growth_mlp.py b/src/networks/growth_mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..ca48fd457dea8460f4eb438e48c59f4cc7a7da74 --- /dev/null +++ b/src/networks/growth_mlp.py @@ -0,0 +1,36 @@ +import sys +import torch +import torch.nn as nn +from typing import List, Optional + +from .mlp_base import SimpleDenseNet + +class GrowthNet(SimpleDenseNet): + def __init__( + self, + dim: int, + activation: str, + hidden_dims: List[int] = None, + batch_norm: bool = False, + negative: bool = False + ): + super().__init__(input_size=dim + 1, target_size=1, + activation=activation, + batch_norm=batch_norm, + hidden_dims=hidden_dims) + + self.softplus = nn.Softplus() + self.negative = negative + + def forward(self, t, x): + + if t.dim() < 1 or t.shape[0] != x.shape[0]: + t = t.repeat(x.shape[0])[:, None] + if t.dim() < 2: + t = t[:, None] + x = torch.cat([t, x], dim=-1) + x = self.model(x) + x = self.softplus(x) + if self.negative: + x = -x + return x \ No newline at end of file diff --git a/src/networks/interpolant_mlp.py b/src/networks/interpolant_mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..b57d6887c493982ee9ef08c99d6bee6d523080fb --- /dev/null +++ b/src/networks/interpolant_mlp.py @@ -0,0 +1,34 @@ +import sys +import torch +import torch.nn as nn +from typing import List, Optional + +from .mlp_base import SimpleDenseNet + +class GeoPathMLP(nn.Module): + def __init__( + self, + input_dim: int, + activation: str, + batch_norm: bool = True, + hidden_dims: Optional[List[int]] = None, + time_geopath: bool = False, + ): + super().__init__() + self.input_dim = input_dim + self.time_geopath = time_geopath + self.mainnet = SimpleDenseNet( + input_size=2 * input_dim + (1 if time_geopath else 0), + target_size=input_dim, + activation=activation, + batch_norm=batch_norm, + hidden_dims=hidden_dims, + ) + + def forward( + self, x0: torch.Tensor, x1: torch.Tensor, t: torch.Tensor + ) -> torch.Tensor: + x = torch.cat([x0, x1], dim=1) + if self.time_geopath: + x = torch.cat([x, t], dim=1) + return self.mainnet(x) \ No newline at end of file diff --git a/src/networks/mlp_base.py b/src/networks/mlp_base.py new file mode 100755 index 0000000000000000000000000000000000000000..455825d4b4b9d6e5e19b3fff616a7e7ccca7c446 --- /dev/null +++ b/src/networks/mlp_base.py @@ -0,0 +1,45 @@ +import sys +import torch.nn as nn +import torch +from typing import List, Optional + +class swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +ACTIVATION_MAP = { + "relu": nn.ReLU, + "sigmoid": nn.Sigmoid, + "tanh": nn.Tanh, + "selu": nn.SELU, + "elu": nn.ELU, + "lrelu": nn.LeakyReLU, + "softplus": nn.Softplus, + "silu": nn.SiLU, + "swish": swish, +} + + +class SimpleDenseNet(nn.Module): + def __init__( + self, + input_size: int, + target_size: int, + activation: str, + batch_norm: bool = False, + hidden_dims: List[int] = None, + ): + super().__init__() + dims = [input_size, *hidden_dims, target_size] + layers = [] + for i in range(len(dims) - 2): + layers.append(nn.Linear(dims[i], dims[i + 1])) + if batch_norm: + layers.append(nn.BatchNorm1d(dims[i + 1])) + layers.append(ACTIVATION_MAP[activation]()) + layers.append(nn.Linear(dims[-2], dims[-1])) + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) diff --git a/src/networks/utils.py b/src/networks/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..38c0997ba75388cb208284be2983bedebf172b36 --- /dev/null +++ b/src/networks/utils.py @@ -0,0 +1,12 @@ +import sys +import torch + +class flow_model_torch_wrapper(torch.nn.Module): + """Wraps model to torchdyn compatible format.""" + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, t, x, *args, **kwargs): + return self.model(t, x) diff --git a/src/utils.py b/src/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..df59f0d308892d6e425c6bc9565248b359a48270 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,379 @@ +import numpy as np +import torch +import random +import matplotlib +import matplotlib.pyplot as plt +import math +import umap +import scanpy as sc +from sklearn.decomposition import PCA + +import ot as pot +from tqdm import tqdm +from functools import partial +from typing import Optional + +from matplotlib.colors import LinearSegmentedColormap + + +def set_seed(seed): + """ + Sets the seed for reproducibility in PyTorch, Numpy, and Python's Random. + + Parameters: + seed (int): The seed for the random number generators. + """ + random.seed(seed) # Python random module + np.random.seed(seed) # Numpy + torch.manual_seed(seed) # CPU and GPU (deterministic) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) # CUDA + torch.cuda.manual_seed_all(seed) # all GPU devices + torch.backends.cudnn.deterministic = True # CuDNN behavior + torch.backends.cudnn.benchmark = False + + +def wasserstein( + x0: torch.Tensor, + x1: torch.Tensor, + method: Optional[str] = None, + reg: float = 0.05, + power: int = 2, + **kwargs, +) -> float: + assert power == 1 or power == 2 + # ot_fn should take (a, b, M) as arguments where a, b are marginals and + # M is a cost matrix + if method == "exact" or method is None: + ot_fn = pot.emd2 + elif method == "sinkhorn": + ot_fn = partial(pot.sinkhorn2, reg=reg) + else: + raise ValueError(f"Unknown method: {method}") + + a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0]) + if x0.dim() > 2: + x0 = x0.reshape(x0.shape[0], -1) + if x1.dim() > 2: + x1 = x1.reshape(x1.shape[0], -1) + M = torch.cdist(x0, x1) + if power == 2: + M = M**2 + ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=1e7) + if power == 2: + ret = math.sqrt(ret) + return ret + +min_var_est = 1e-8 + + +# Consider linear time MMD with a linear kernel: +# K(f(x), f(y)) = f(x)^Tf(y) +# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i) +# = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)] +# +# f_of_X: batch_size * k +# f_of_Y: batch_size * k +def linear_mmd2(f_of_X, f_of_Y): + loss = 0.0 + delta = f_of_X - f_of_Y + loss = torch.mean((delta[:-1] * delta[1:]).sum(1)) + return loss + + +# Consider linear time MMD with a polynomial kernel: +# K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d +# f_of_X: batch_size * k +# f_of_Y: batch_size * k +def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0): + K_XX = alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c + K_XX_mean = torch.mean(K_XX.pow(d)) + + K_YY = alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c + K_YY_mean = torch.mean(K_YY.pow(d)) + + K_XY = alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c + K_XY_mean = torch.mean(K_XY.pow(d)) + + K_YX = alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c + K_YX_mean = torch.mean(K_YX.pow(d)) + + return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean + + +def _mix_rbf_kernel(X, Y, sigma_list): + assert X.size(0) == Y.size(0) + m = X.size(0) + + Z = torch.cat((X, Y), 0) + ZZT = torch.mm(Z, Z.t()) + diag_ZZT = torch.diag(ZZT).unsqueeze(1) + Z_norm_sqr = diag_ZZT.expand_as(ZZT) + exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t() + + K = 0.0 + for sigma in sigma_list: + gamma = 1.0 / (2 * sigma**2) + K += torch.exp(-gamma * exponent) + + return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list) + + +def mix_rbf_mmd2(X, Y, sigma_list, biased=True): + K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) + # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) + return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) + + +def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True): + K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) + # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) + return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) + + +################################################################################ +# Helper functions to compute variances based on kernel matrices +################################################################################ + + +def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): + m = K_XX.size(0) # assume X, Y are same shape + + # Get the various sums of kernels that we'll use + # Kts drop the diagonal, but we don't need to compute them explicitly + if const_diagonal is not False: + diag_X = diag_Y = const_diagonal + sum_diag_X = sum_diag_Y = m * const_diagonal + else: + diag_X = torch.diag(K_XX) # (m,) + diag_Y = torch.diag(K_YY) # (m,) + sum_diag_X = torch.sum(diag_X) + sum_diag_Y = torch.sum(diag_Y) + + Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X + Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y + K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e + + Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e + Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e + K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e + + if biased: + mmd2 = ( + (Kt_XX_sum + sum_diag_X) / (m * m) + + (Kt_YY_sum + sum_diag_Y) / (m * m) + - 2.0 * K_XY_sum / (m * m) + ) + else: + mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m) + + return mmd2 + + +def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): + mmd2, var_est = _mmd2_and_variance( + K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased + ) + loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est)) + return loss, mmd2, var_est + + +def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): + m = K_XX.size(0) # assume X, Y are same shape + + # Get the various sums of kernels that we'll use + # Kts drop the diagonal, but we don't need to compute them explicitly + if const_diagonal is not False: + diag_X = diag_Y = const_diagonal + sum_diag_X = sum_diag_Y = m * const_diagonal + sum_diag2_X = sum_diag2_Y = m * const_diagonal**2 + else: + diag_X = torch.diag(K_XX) # (m,) + diag_Y = torch.diag(K_YY) # (m,) + sum_diag_X = torch.sum(diag_X) + sum_diag_Y = torch.sum(diag_Y) + sum_diag2_X = diag_X.dot(diag_X) + sum_diag2_Y = diag_Y.dot(diag_Y) + + Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X + Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y + K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e + K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e + + Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e + Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e + K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e + + Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2 + Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2 + K_XY_2_sum = (K_XY**2).sum() # \| K_{XY} \|_F^2 + + if biased: + mmd2 = ( + (Kt_XX_sum + sum_diag_X) / (m * m) + + (Kt_YY_sum + sum_diag_Y) / (m * m) + - 2.0 * K_XY_sum / (m * m) + ) + else: + mmd2 = Kt_XX_sum / (m * (m - 1)) + Kt_YY_sum / (m * (m - 1)) - 2.0 * K_XY_sum / (m * m) + + var_est = ( + 2.0 + / (m**2 * (m - 1.0) ** 2) + * ( + 2 * Kt_XX_sums.dot(Kt_XX_sums) + - Kt_XX_2_sum + + 2 * Kt_YY_sums.dot(Kt_YY_sums) + - Kt_YY_2_sum + ) + - (4.0 * m - 6.0) / (m**3 * (m - 1.0) ** 3) * (Kt_XX_sum**2 + Kt_YY_sum**2) + + 4.0 + * (m - 2.0) + / (m**3 * (m - 1.0) ** 2) + * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0)) + - 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * (K_XY_2_sum) + - (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2 + + 8.0 + / (m**3 * (m - 1.0)) + * ( + 1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum + - Kt_XX_sums.dot(K_XY_sums_1) + - Kt_YY_sums.dot(K_XY_sums_0) + ) + ) + return mmd2, var_est + + +def plot_lidar(ax, dataset, xs=None, S=25, branch_idx=None): + # Combine the dataset and trajectory points for sorting + combined_points = [] + combined_colors = [] + combined_sizes = [] + + + custom_colors_1 = ["#05009E", "#A19EFF", "#50B2D7"] + custom_colors_2 = ["#05009E", "#A19EFF", "#D577FF"] + + custom_cmap_1 = LinearSegmentedColormap.from_list("my_cmap", custom_colors_1) + custom_cmap_2 = LinearSegmentedColormap.from_list("my_cmap", custom_colors_2) + + # Normalize the z-coordinates for alpha scaling + z_coords = ( + dataset[:, 2].numpy() if torch.is_tensor(dataset[:, 2]) else dataset[:, 2] + ) + z_min, z_max = z_coords.min(), z_coords.max() + z_norm = (z_coords - z_min) / (z_max - z_min) + + # Add surface points with a lower z-order + for i, point in enumerate(dataset): + grey_value = 0.95 - 0.7 * z_norm[i] + combined_points.append(point.numpy()) + combined_colors.append( + ( + grey_value, + grey_value, + grey_value, + 1.0 + ) + ) # Grey color with transparency + combined_sizes.append(0.1) + + # Add trajectory points with a higher z-order + if xs is not None: + if branch_idx == 0: + cmap = custom_cmap_1 + else: + cmap = custom_cmap_2 + + B, T, D = xs.shape + steps_to_log = np.linspace(0, T - 1, S).astype(int) + xs = xs.cpu().detach().clone() + for idx, step in enumerate(steps_to_log): + for point in xs[:512, step]: + combined_points.append( + point.numpy() if torch.is_tensor(point) else point + ) + combined_colors.append(cmap(idx / (len(steps_to_log) - 1))) + combined_sizes.append(0.8) + + # Convert to numpy array for easier manipulation + combined_points = np.array(combined_points) + combined_colors = np.array(combined_colors) + combined_sizes = np.array(combined_sizes) + + # Sort by z-coordinate (depth) + sorted_indices = np.argsort(combined_points[:, 2]) + combined_points = combined_points[sorted_indices] + combined_colors = combined_colors[sorted_indices] + combined_sizes = combined_sizes[sorted_indices] + + # Plot the sorted points + ax.scatter( + combined_points[:, 0], + combined_points[:, 1], + combined_points[:, 2], + s=combined_sizes, + c=combined_colors, + depthshade=True, + ) + + ax.set_xlim3d(left=-4.8, right=4.8) + ax.set_ylim3d(bottom=-4.8, top=4.8) + ax.set_zlim3d(bottom=0.0, top=2.0) + ax.set_zticks([0, 1.0, 2.0]) + ax.grid(False) + plt.axis("off") + + return ax + + +def plot_images_trajectory(trajectories, vae, processor, num_steps): + + # Compute trajectories for each image + t_span = torch.linspace(0, trajectories.shape[1] - 1, num_steps) + t_span = [int(t) for t in t_span] + num_images = trajectories.shape[0] + + # Decode images at each step in each trajectory + decoded_images = [ + [ + processor.postprocess( + vae.decode( + trajectories[i_image, traj_step].unsqueeze(0) + ).sample.detach() + )[0] + for traj_step in t_span + ] + for i_image in range(num_images) + ] + + # Plotting + fig, axes = plt.subplots( + num_images, num_steps, figsize=(num_steps * 2, num_images * 2) + ) + if num_images == 1: + axes = [axes] # Ensure axes is iterable + for img_idx, img_traj in enumerate(decoded_images): + for step_idx, img in enumerate(img_traj): + ax = axes[img_idx][step_idx] if num_images > 1 else axes[step_idx] + if ( + isinstance(img, np.ndarray) and img.shape[0] == 3 + ): # Assuming 3 channels (RGB) + img = img.transpose(1, 2, 0) + ax.imshow(img) + ax.axis("off") + if img_idx == 0: + ax.set_title(f"t={t_span[step_idx]/t_span[-1]:.2f}") + plt.tight_layout() + return fig + + +def plot_growth(dataset, growth_nets, xs, output_file='plot.pdf'): + x0s = [dataset["x0"][0]] + w0s = [dataset["x0"][1]] + x1s_list = [[dataset["x1_1"][0]], [dataset["x1_2"][0]]] + w1s_list = [[dataset["x1_1"][1]], [dataset["x1_2"][1]]] + + + + plt.show() \ No newline at end of file diff --git a/train.py b/train.py new file mode 100755 index 0000000000000000000000000000000000000000..837658fa7fa6c1020b3ebb70c2a349615bb5690f --- /dev/null +++ b/train.py @@ -0,0 +1,537 @@ +import sys +import os +import argparse +import copy +import time +import json + +import torch.nn as nn +import wandb +from pytorch_lightning import Trainer +from pytorch_lightning.loggers import WandbLogger +from torchcfm.optimal_transport import OTPlanSampler + +from parsers import parse_args +from train_utils import load_config, merge_config, generate_group_string, dataset_name2datapath, create_callbacks +from src.branchsbm import BranchSBM +from src.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar +from src.branch_flow_net_test import ( + FlowNetTestLidar, FlowNetTestMouse, FlowNetTestClonidine, FlowNetTestTrametinib, FlowNetTestVeres +) +from src.branch_interpolant_train import BranchInterpolantTrain +from src.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar, SequentialGrowthNetTrain +from src.networks.flow_mlp import VelocityNet +from src.networks.growth_mlp import GrowthNet +from src.networks.interpolant_mlp import GeoPathMLP +from src.utils import set_seed +from src.ema import EMA +from src.geo_metrics.metric_factory import DataManifoldMetric +from dataloaders.mouse_data import WeightedBranchedCellDataModule, SingleBranchCellDataModule +from dataloaders.three_branch_data import ThreeBranchTahoeDataModule +from dataloaders.clonidine_v2_data import ClonidineV2DataModule +from dataloaders.clonidine_single_branch import ClonidineSingleBranchDataModule +from dataloaders.trametinib_single import TrametinibSingleBranchDataModule +from dataloaders.lidar_data import WeightedBranchedLidarDataModule +from dataloaders.lidar_data_single import LidarSingleDataModule +from dataloaders.veres_leiden_data import WeightedBranchedVeresDataModule + +def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None: + set_seed(seed) + branches = args.branches + + skipped_time_points = [t_exclude] if t_exclude else [] + print("config path:") + print(args.config_path) + print("whiten") + print(args.whiten) + + # Add date and time prefix to run name for distinguishable results + current_datetime = time.strftime("%m_%d_%H%M", time.localtime()) + run_name_with_datetime = f"{current_datetime}_{args.run_name}" + + # Update args.run_name so test classes use the dated name + args.run_name = run_name_with_datetime + + ### DATAMODULES + + ### DATAMODULES ### + if args.data_name == "lidar": + datamodule = WeightedBranchedLidarDataModule(args=args) + elif args.data_name == "lidarsingle": + datamodule = LidarSingleDataModule(args=args) + elif args.data_name == "mouse": + datamodule = WeightedBranchedCellDataModule(args=args) + elif args.data_name == "mousesingle": + datamodule = SingleBranchCellDataModule(args=args) + elif args.data_name in ["clonidine50D", "clonidine100D", "clonidine150D"]: + datamodule = ClonidineV2DataModule(args=args) + elif args.data_name == "clonidine50Dsingle": + datamodule = ClonidineSingleBranchDataModule(args=args) + elif args.data_name == "trametinib": + datamodule = ThreeBranchTahoeDataModule(args=args) + elif args.data_name == "trametinibsingle": + datamodule = TrametinibSingleBranchDataModule(args=args) + elif args.data_name == "veres": + datamodule = WeightedBranchedVeresDataModule(args=args) + branches = datamodule.num_branches + print("number of branches:", branches) + + flow_nets = nn.ModuleList() + geopath_nets = nn.ModuleList() + growth_nets = nn.ModuleList() + + ##### initialize branched flow and growth networks ##### + for i in range(branches): + flow_net = VelocityNet( + dim=args.dim, + hidden_dims=args.hidden_dims_flow, + activation=args.activation_flow, + batch_norm=False, + ) + geopath_net = GeoPathMLP( + input_dim=args.dim, + hidden_dims=args.hidden_dims_geopath, + time_geopath=args.time_geopath, + activation=args.activation_geopath, + batch_norm=False, + ) + + if i == 0: + growth_net = GrowthNet( + dim=args.dim, + hidden_dims=args.hidden_dims_growth, + activation=args.activation_growth, + batch_norm=False, + negative=True + ) + else: + growth_net = GrowthNet( + dim=args.dim, + hidden_dims=args.hidden_dims_growth, + activation=args.activation_growth, + batch_norm=False, + negative=False + ) + + if args.ema_decay is not None: + flow_net = EMA(model=flow_net, decay=args.ema_decay) + geopath_net = EMA(model=geopath_net, decay=args.ema_decay) + growth_net = EMA(model=growth_net, decay=args.ema_decay) + + flow_nets.append(flow_net) + geopath_nets.append(geopath_net) + growth_nets.append(growth_net) + + + ot_sampler = ( + OTPlanSampler(method=args.optimal_transport_method) + if args.optimal_transport_method != "None" + else None + ) + + wandb.init( + project="branchsbm", + name=run_name_with_datetime, + config=vars(args), + dir=args.working_dir, + ) + + flow_matcher_base = BranchSBM( + geopath_nets=geopath_nets, + sigma=args.sigma, + alpha=int(args.branchsbm), + ) + + ##### STAGE 1: Training of Geodesic Interpolants Beginning ##### + geopath_callbacks = create_callbacks( + args, phase="geopath", data_type=args.data_type, run_id=wandb.run.id + ) + + # define state cost + data_manifold_metric = DataManifoldMetric( + args=args, + skipped_time_points=skipped_time_points, + datamodule=datamodule, + ) + geopath_model = BranchInterpolantTrain( + flow_matcher=flow_matcher_base, + skipped_time_points=skipped_time_points, + ot_sampler=ot_sampler, + args=args, + data_manifold_metric=data_manifold_metric + ) + + wandb_logger = WandbLogger(version=run_name_with_datetime) + + trainer = Trainer( + max_epochs=args.epochs, + callbacks=geopath_callbacks, + accelerator=args.accelerator, + logger=wandb_logger, + num_sanity_val_steps=0, + default_root_dir=args.working_dir, + gradient_clip_val=(1.0 if args.data_type == "image" else None), + ) + + if args.load_geopath_model_ckpt: + best_model_path = args.load_geopath_model_ckpt + else: + trainer.fit( + geopath_model, + datamodule=datamodule, + ) + + best_model_path = geopath_callbacks[0].best_model_path + + geopath_model = BranchInterpolantTrain.load_from_checkpoint(best_model_path) + + flow_matcher_base.geopath_nets = geopath_model.geopath_nets + + ##### STAGE 1: Training of Geodesic Interpolants End ##### + + ##### STAGE 2: Flow Matching Beginning ##### + flow_callbacks = create_callbacks( + args, + phase="flow", + data_type=args.data_type, + run_id=wandb.run.id, + datamodule=datamodule, + ) + + if args.data_type == "lidar": + FlowNetTrain = FlowNetTrainLidar + else: + FlowNetTrain = FlowNetTrainCell + + flow_train = FlowNetTrain( + flow_matcher=flow_matcher_base, + flow_nets=flow_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + ) + + # Reuse existing wandb run from Stage 1 + wandb_logger = WandbLogger(version=run_name_with_datetime) + + trainer = Trainer( + max_epochs=args.epochs, + callbacks=flow_callbacks, + check_val_every_n_epoch=args.check_val_every_n_epoch, + accelerator=args.accelerator, + logger=wandb_logger, + default_root_dir=args.working_dir, + gradient_clip_val=(1.0 if args.data_type == "image" else None), + num_sanity_val_steps=(0 if args.data_type == "image" else None), + ) + + trainer.fit( + flow_train, datamodule=datamodule, ckpt_path=args.resume_flow_model_ckpt + ) + if args.data_type == "lidar": + trainer.test(flow_train, datamodule=datamodule) + ##### STAGE 2: Flow Matching End ##### + + ##### STAGE 3: Training Growth Networks Beginning #### + flow_nets = flow_train.flow_nets + + growth_callbacks = create_callbacks( + args, + phase="growth", + data_type=args.data_type, + run_id=wandb.run.id, + datamodule=datamodule, + ) + + if args.data_type == "lidar": + GrowthNetTrainClass = GrowthNetTrainLidar + else: + GrowthNetTrainClass = GrowthNetTrainCell + + growth_train = GrowthNetTrainClass( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = False + ) + + # Reuse existing wandb run + wandb_logger = WandbLogger(version=run_name_with_datetime) + + trainer = Trainer( + max_epochs=args.epochs, + callbacks=growth_callbacks, + check_val_every_n_epoch=args.check_val_every_n_epoch, + accelerator=args.accelerator, + logger=wandb_logger, + default_root_dir=args.working_dir, + gradient_clip_val=(1.0 if args.data_type == "image" else None), + num_sanity_val_steps=(0 if args.data_type == "image" else None), + ) + + trainer.fit( + growth_train, datamodule=datamodule, ckpt_path=None + ) + + # Load best checkpoint for testing + best_growth_path = growth_callbacks[0].best_model_path + if best_growth_path: + print(f"Loading best growth model from: {best_growth_path}") + if args.sequential: + growth_train = SequentialGrowthNetTrain.load_from_checkpoint( + best_growth_path, + flow_nets=flow_nets, + growth_nets=growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint=False + ) + else: + growth_train = GrowthNetTrainClass.load_from_checkpoint( + best_growth_path, + flow_nets=flow_nets, + growth_nets=growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint=False + ) + # Extract the trained flow_nets from the loaded checkpoint + flow_nets = growth_train.flow_nets + # Ensure flow_nets and growth_nets are ModuleList (not tuple) + if isinstance(flow_nets, tuple): + flow_nets = nn.ModuleList(flow_nets) + if isinstance(growth_nets, tuple): + growth_nets = nn.ModuleList(growth_nets) + + # Use appropriate test class based on data type + if "lidar" in args.data_name.lower(): + test_model = FlowNetTestLidar( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = False + ) + elif "mouse" in args.data_name.lower(): + test_model = FlowNetTestMouse( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = False + ) + elif "clonidine" in args.data_name.lower(): + test_model = FlowNetTestClonidine( + flow_matcher=flow_matcher_base, + flow_nets=flow_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + ) + elif "trametinib" in args.data_name.lower(): + test_model = FlowNetTestTrametinib( + flow_matcher=flow_matcher_base, + flow_nets=flow_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + ) + elif "veres" in args.data_name.lower(): + test_model = FlowNetTestVeres( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = False + ) + else: + # Default to growth_train test + test_model = growth_train + + trainer.test(test_model, datamodule=datamodule) + + ##### STAGE 3: Training Growth Networks End #### + + ##### STAGE 4: Joint Training Beginning #### + + growth_nets = growth_train.growth_nets + + joint_callbacks = create_callbacks( + args, + phase="joint", + data_type=args.data_type, + run_id=wandb.run.id, + datamodule=datamodule, + ) + + if args.sequential: + joint_train = SequentialGrowthNetTrain( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = True + ) + else: + if args.data_type == "lidar": + GrowthNetTrainClass = GrowthNetTrainLidar + else: + GrowthNetTrainClass = GrowthNetTrainCell + + joint_train = GrowthNetTrainClass( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = True + ) + + # Reuse existing wandb run + wandb_logger = WandbLogger(version=run_name_with_datetime) + + trainer = Trainer( + max_epochs=args.epochs, + callbacks=joint_callbacks, + check_val_every_n_epoch=args.check_val_every_n_epoch, + accelerator=args.accelerator, + logger=wandb_logger, + default_root_dir=args.working_dir, + gradient_clip_val=(1.0 if args.data_type == "image" else None), + num_sanity_val_steps=(0 if args.data_type == "image" else None), + ) + + trainer.fit( + joint_train, datamodule=datamodule, ckpt_path=None + ) + + # Load best checkpoint for testing + best_joint_path = joint_callbacks[0].best_model_path + if best_joint_path: + print(f"Loading best joint model from: {best_joint_path}") + if args.sequential: + joint_train = SequentialGrowthNetTrain.load_from_checkpoint( + best_joint_path, + flow_nets=flow_nets, + growth_nets=growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint=True + ) + else: + joint_train = GrowthNetTrainClass.load_from_checkpoint( + best_joint_path, + flow_nets=flow_nets, + growth_nets=growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint=True + ) + # Extract the trained flow_nets and growth_nets from the loaded checkpoint + flow_nets = joint_train.flow_nets + growth_nets = joint_train.growth_nets + # Ensure flow_nets and growth_nets are ModuleList (not tuple) + if isinstance(flow_nets, tuple): + flow_nets = nn.ModuleList(flow_nets) + if isinstance(growth_nets, tuple): + growth_nets = nn.ModuleList(growth_nets) + + # Use appropriate test class based on data type + if "lidar" in args.data_name.lower(): + test_model = FlowNetTestLidar( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = True + ) + elif "mouse" in args.data_name.lower(): + test_model = FlowNetTestMouse( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = True + ) + elif "clonidine" in args.data_name.lower(): + test_model = FlowNetTestClonidine( + flow_matcher=flow_matcher_base, + flow_nets=flow_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + ) + elif "trametinib" in args.data_name.lower(): + test_model = FlowNetTestTrametinib( + flow_matcher=flow_matcher_base, + flow_nets=flow_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + ) + elif "veres" in args.data_name.lower(): + test_model = FlowNetTestVeres( + flow_nets = flow_nets, + growth_nets = growth_nets, + ot_sampler=ot_sampler, + skipped_time_points=skipped_time_points, + args=args, + data_manifold_metric=data_manifold_metric, + joint = True + ) + else: + test_model = joint_train + test_model = joint_train + + trainer.test(test_model, datamodule=datamodule) + + ##### STAGE 4: Joint Training End #### + + wandb.finish() + +if __name__ == "__main__": + args = parse_args() + updated_args = copy.deepcopy(args) + if args.config_path: + config = load_config(args.config_path) + updated_args = merge_config(updated_args, config) + + updated_args.group_name = generate_group_string() + updated_args.data_path = dataset_name2datapath( + updated_args.data_name, updated_args.working_dir + ) + for seed in updated_args.seeds: + if updated_args.t_exclude: + for i, t_exclude in enumerate(updated_args.t_exclude): + updated_args.t_exclude_current = t_exclude + updated_args.seed_current = seed + updated_args.gamma_current = updated_args.gammas[i] + main(updated_args, seed=seed, t_exclude=t_exclude) + else: + updated_args.seed_current = seed + updated_args.gamma_current = updated_args.gammas[0] + main(updated_args, seed=seed, t_exclude=None) diff --git a/train_utils.py b/train_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..c6660d325039039fefac2a682e7483117074683c --- /dev/null +++ b/train_utils.py @@ -0,0 +1,156 @@ +import sys +import yaml +import string +import secrets +import os +import torch +import wandb +from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint +from torchdyn.core import NeuralODE +from src.utils import plot_images_trajectory +from src.networks.utils import flow_model_torch_wrapper + + +def load_config(path): + with open(path, "r") as file: + config = yaml.safe_load(file) + return config + + +def merge_config(args, config_updates): + for key, value in config_updates.items(): + if not hasattr(args, key): + raise ValueError( + f"Unknown configuration parameter '{key}' found in the config file." + ) + setattr(args, key, value) + return args + +def generate_group_string(length=16): + alphabet = string.ascii_letters + string.digits + return "".join(secrets.choice(alphabet) for _ in range(length)) + + +def dataset_name2datapath(dataset_name, working_dir): + if dataset_name in ["lidar", "lidarsingle"]: + return os.path.join(working_dir, "data", "rainier2-thin.las") + elif dataset_name in ["mouse", "mousesingle"]: + return os.path.join(working_dir, "data", "mouse_hematopoiesis.csv") + elif dataset_name in ["clonidine50D", "clonidine100D", "clonidine150D", "clonidine50Dsingle", "clonidine100Dsingle", "clonidine150Dsingle"]: + return os.path.join(working_dir, "data", "pca_and_leiden_labels.csv") + elif dataset_name in ["trametinib", "trametinibsingle"]: + return os.path.join(working_dir, "data", "Trametinib_5.0uM_pca_and_leidenumap_labels.csv") + elif dataset_name in ["veres", "veressingle"]: + return os.path.join(working_dir, "data", "Veres_alltime.csv") + else: + raise ValueError("Dataset not recognized") + + +def create_callbacks(args, phase, data_type, run_id, datamodule=None): + + dirpath = os.path.join( + args.working_dir, + "checkpoints", + data_type, + str(args.run_name), + f"{phase}_model", + ) + + if phase == "geopath": + early_stop_callback = EarlyStopping( + monitor="BranchPathNet/train_loss_geopath_epoch", + patience=args.patience_geopath, + mode="min", + ) + checkpoint_callback = ModelCheckpoint( + dirpath=dirpath, + monitor="BranchPathNet/train_loss_geopath_epoch", + mode="min", + save_top_k=1, + ) + callbacks = [checkpoint_callback, early_stop_callback] + elif phase == "flow": + early_stop_callback = EarlyStopping( + monitor="FlowNet/train_loss_cfm", + patience=args.patience, + mode="min", + ) + checkpoint_callback = ModelCheckpoint( + dirpath=dirpath, + monitor="FlowNet/train_loss_cfm", + mode="min", + save_top_k=1, + ) + callbacks = [checkpoint_callback, early_stop_callback] + elif phase == "growth": + early_stop_callback = EarlyStopping( + monitor="GrowthNet/train_loss", + patience=args.patience, + mode="min", + ) + checkpoint_callback = ModelCheckpoint( + dirpath=dirpath, + monitor="GrowthNet/train_loss", + mode="min", + save_top_k=1, + ) + callbacks = [checkpoint_callback, early_stop_callback] + elif phase == "joint": + early_stop_callback = EarlyStopping( + monitor="JointTrain/train_loss", + patience=args.patience, + mode="min", + ) + checkpoint_callback = ModelCheckpoint( + dirpath=dirpath, + mode="min", + save_top_k=1, + ) + callbacks = [checkpoint_callback, early_stop_callback] + else: + raise ValueError("Unknown phase") + return callbacks + + +class PlottingCallback(Callback): + def __init__(self, plot_interval, datamodule): + self.plot_interval = plot_interval + self.datamodule = datamodule + + def on_train_epoch_end(self, trainer, pl_module): + epoch = trainer.current_epoch + pl_module.flow_net.train(mode=False) + if epoch % self.plot_interval == 0 and epoch != 0: + node = NeuralODE( + flow_model_torch_wrapper(pl_module.flow_net).to(self.datamodule.device), + solver="tsit5", + sensitivity="adjoint", + atol=1e-5, + rtol=1e-5, + ) + + for mode in ["train", "val"]: + x0 = getattr(self.datamodule, f"{mode}_x0") + x0 = x0[0:15] + fig = self.trajectory_and_plot(x0, node, self.datamodule) + wandb.log({f"Trajectories {mode.capitalize()}": wandb.Image(fig)}) + pl_module.flow_net.train(mode=True) + + def trajectory_and_plot(self, x0, node, datamodule): + selected_images = x0[0:15] + with torch.no_grad(): + traj = node.trajectory( + selected_images.to(datamodule.device), + t_span=torch.linspace(0, 1, 100).to(datamodule.device), + ) + + traj = traj.transpose(0, 1) + traj = traj.reshape(*traj.shape[0:2], *datamodule.dim) + + fig = plot_images_trajectory( + traj.to(datamodule.device), + datamodule.vae.to(datamodule.device), + datamodule.process, + num_steps=5, + ) + return fig \ No newline at end of file