Sophia Tang committed on
Commit
b55bace
·
0 Parent(s):

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +61 -0
  2. .gitignore +15 -0
  3. LICENSE +21 -0
  4. README.md +68 -0
  5. assets/branchsbm.png +3 -0
  6. assets/branchsbm_anim.gif +3 -0
  7. assets/clonidine.png +3 -0
  8. assets/lidar.png +3 -0
  9. assets/mouse.png +3 -0
  10. assets/trametinib.png +3 -0
  11. assets/veres.png +3 -0
  12. configs/.DS_Store +0 -0
  13. configs/clonidine_100D.yaml +22 -0
  14. configs/clonidine_150D.yaml +22 -0
  15. configs/clonidine_50D.yaml +22 -0
  16. configs/clonidine_50Dsingle.yaml +22 -0
  17. configs/lidar.yaml +15 -0
  18. configs/lidar_single.yaml +15 -0
  19. configs/mouse.yaml +18 -0
  20. configs/mouse_single.yaml +18 -0
  21. configs/trametinib.yaml +22 -0
  22. configs/trametinib_single.yaml +22 -0
  23. configs/veres.yaml +25 -0
  24. dataloaders/.DS_Store +0 -0
  25. dataloaders/clonidine_single_branch.py +265 -0
  26. dataloaders/clonidine_v2_data.py +280 -0
  27. dataloaders/lidar_data.py +529 -0
  28. dataloaders/lidar_data_single.py +274 -0
  29. dataloaders/mouse_data.py +453 -0
  30. dataloaders/three_branch_data.py +306 -0
  31. dataloaders/trametinib_single.py +268 -0
  32. dataloaders/veres_leiden_data.py +317 -0
  33. environment.yml +41 -0
  34. parsers.py +502 -0
  35. scripts/README.md +226 -0
  36. scripts/clonidine100.sh +26 -0
  37. scripts/clonidine150.sh +26 -0
  38. scripts/clonidine50.sh +26 -0
  39. scripts/clonidine50_single.sh +26 -0
  40. scripts/lidar.sh +26 -0
  41. scripts/lidar_single.sh +27 -0
  42. scripts/mouse.sh +25 -0
  43. scripts/mouse_single.sh +25 -0
  44. scripts/trametinib.sh +26 -0
  45. scripts/trametinib_single.sh +26 -0
  46. scripts/veres.sh +26 -0
  47. src/.DS_Store +0 -0
  48. src/branch_flow_net_test.py +1791 -0
  49. src/branch_flow_net_train.py +375 -0
  50. src/branch_growth_net_train.py +994 -0
.gitattributes ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ branchsbm.png filter=lfs diff=lfs merge=lfs -text
37
+ branchsbm/branchsbm.png filter=lfs diff=lfs merge=lfs -text
38
+ branchsbm/clonidine.png filter=lfs diff=lfs merge=lfs -text
39
+ branchsbm/lidar.png filter=lfs diff=lfs merge=lfs -text
40
+ branchsbm/mouse.png filter=lfs diff=lfs merge=lfs -text
41
+ branchsbm/trametinib.png filter=lfs diff=lfs merge=lfs -text
42
+ clonidine.png filter=lfs diff=lfs merge=lfs -text
43
+ lidar.png filter=lfs diff=lfs merge=lfs -text
44
+ mouse.png filter=lfs diff=lfs merge=lfs -text
45
+ trametinib.png filter=lfs diff=lfs merge=lfs -text
46
+ data/pca_and_leiden_labels.csv filter=lfs diff=lfs merge=lfs -text
47
+ data/mouse_hematopoiesis.csv filter=lfs diff=lfs merge=lfs -text
48
+ data/simulation_gene.csv filter=lfs diff=lfs merge=lfs -text
49
+ data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv filter=lfs diff=lfs merge=lfs -text
50
+ data/Veres_alltime.csv filter=lfs diff=lfs merge=lfs -text
51
+ data/Weinreb_alltime.csv filter=lfs diff=lfs merge=lfs -text
52
+ data/Weinreb_t2_leiden_clusters.csv filter=lfs diff=lfs merge=lfs -text
53
+ data/eb_noscale.csv filter=lfs diff=lfs merge=lfs -text
54
+ data/emt.csv filter=lfs diff=lfs merge=lfs -text
55
+ *.csv filter=lfs diff=lfs merge=lfs -text
56
+ data/*.las filter=lfs diff=lfs merge=lfs -text
57
+ *.csv filter=lfs diff=lfs merge=lfs -text
58
+ data/*.las filter=lfs diff=lfs merge=lfs -text
59
+ *.png filter=lfs diff=lfs merge=lfs -text
60
+ assets/veres.png filter=lfs diff=lfs merge=lfs -text
61
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ logs/
2
+ wandb/
3
+ __pycache__/
4
+ checkpoints/
5
+ lightning_logs/
6
+ results/
7
+ *.log
8
+ *.pyc
9
+ lightning_logs/
10
+ figures/
11
+ *.ckpt
12
+ *.csv
13
+ data/
14
+ extra/
15
+ .vscode/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Sophia Tang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # [Branched Schrödinger Bridge Matching](https://arxiv.org/abs/2506.09007) (ICLR 2026) 🌳🧬
2
+
3
+ [**Sophia Tang**](https://sophtang.github.io/), [**Yinuo Zhang**](https://www.linkedin.com/in/yinuozhang98/), [**Alexander Tong**](https://www.alextong.net/) and [**Pranam Chatterjee**](https://www.chatterjeelab.com/)
4
+
5
+ ![BranchSBM](assets/branchsbm_anim.gif)
6
+
7
+ This is the repository for [**Branched Schrödinger Bridge Matching**](https://arxiv.org/abs/2506.09007) (ICLR 2026) 🌳🧬. It is partially built on the [**Metric Flow Matching repo**](https://github.com/kkapusniak/metric-flow-matching) ([Kapusniak et al., 2024](https://arxiv.org/abs/2405.14780)).
8
+
9
+ Predicting how a population evolves between an initial and final state is central to many problems in generative modeling, from simulating perturbation responses to modelling cell fate decisions 🧫. Existing approaches, such as flow matching and Schrödinger Bridge Matching, effectively learn mappings between two distributions by modelling a single stochastic path. However, these methods are **inherently limited to unimodal transitions and cannot capture branched or divergent evolution from a common origin to multiple distinct outcomes.**
10
+
11
+ A key challenge in trajectory matching is reconstructing multi-modal marginals, particularly when modes diverge along distinct dynamical paths. Existing Schrödinger bridge and flow matching frameworks approximate multi-modal distributions by simulating many *independent* particle trajectories, which are susceptible to mode collapse, with particles concentrating on dominant high-density modes or traversing only low-energy intermediate paths.
12
+
13
+ To address this, we introduce **Branched Schrödinger Bridge Matching (BranchSBM)** 🌳🧬, a novel framework that learns a set of diverging velocity fields to reconstruct multi-modal target distributions while simultaneously learning growth networks that allocate mass across branches. Guided by a time-dependent potential energy function $V_t$, BranchSBM captures diverging, energy-minimizing dynamics without requiring intermediate-time supervision and can generate the full branched evolution from a single initial sample.
14
+
15
+ 🌟 We define the **Branched Generalized Schrödinger Bridge problem** and introduce BranchSBM, a novel matching framework that learns optimal branched trajectories from an initial distribution to multiple target distributions.
16
+
17
+ 🌟 We derive the Branched Conditional Stochastic Optimal Control (CondSOC) problem as the sum of Unbalanced CondSOC objectives and leverage a multi-stage training algorithm to learn the optimal branching drift and growth fields that transport mass along a branched trajectory.
18
+
19
+ 🌟 We demonstrate the unique capability of BranchSBM to model dynamic branching trajectories across various real-world problems, including 3D navigation over LiDAR manifolds, modelling differentiating single-cell population dynamics, and simulating heterogeneous cellular responses to drug perturbation.
20
+
21
+ # Experiments
22
+ Code and instructions to reproduce our results are provided in `/scripts/README`.
23
+
24
+ ## LiDAR Experiment 🗻
25
+
26
+ As a proof of concept, we first evaluate BranchSBM for navigating branched paths along the surface of a three-dimensional LiDAR manifold, from an initial distribution to two distinct target distributions while remaining on low-altitude regions of the manifold.
27
+
28
+ ![LiDAR Experiment](assets/lidar.png)
29
+
30
+ ## Mouse Hematopoiesis and Pancreatic β-Cell Experiment 🧫
31
+
32
+ BranchSBM is uniquely positioned to model single-cell population dynamics where a homogeneous cell population (e.g., progenitor cells) differentiates into several distinct subpopulation branches, each of which independently undergoes growth dynamics. In this experiment, we demonstrate this capability on mouse hematopoiesis data and pancreatic β-cell differentiation data.
33
+
34
+ We evaluate BranchSBM on a mouse hematopoiesis scRNA-seq dataset containing three developmental time points representing progenitor cells differentiating into two terminal cell fates. Compared to a single-branch SBM, BranchSBM successfully learns distinct branching trajectories and accurately reconstructs intermediate cell states, demonstrating its ability to recover lineage bifurcation dynamics.
35
+
36
+ ![Mouse Experiment](assets/mouse.png)
37
+
38
+ We evaluate BranchSBM on a pancreatic β-cell differentiation dataset ([Veres et al., 2019](https://www.nature.com/articles/s41586-019-1168-5)) containing 51,274 cells collected across eight time points as human pluripotent stem cells differentiate into pancreatic β-like cells. Cells are projected into a 30-dimensional PCA space, and Leiden clustering is used to define 11 terminal cell populations at the final time point.
39
+
40
+ BranchSBM is trained using only samples from the initial and final states, while intermediate distributions are inferred by learning trajectories constrained to the data manifold using an RBF state cost. Compared to baselines, BranchSBM significantly improves reconstruction of both intermediate and terminal distributions, achieving lower Wasserstein distances at validation time points. These results demonstrate that BranchSBM can accurately recover branching differentiation dynamics without intermediate supervision.
41
+
42
+ ![Veres Experiment](assets/veres.png)
43
+
44
+ ## Cell Perturbation Modelling Experiment 💉
45
+ Predicting the effects of perturbation on cell state dynamics is a crucial problem for therapeutic design. In this experiment, we leverage BranchSBM to model the **trajectories of a single cell line from a single homogeneous state to multiple heterogeneous states after a drug-induced perturbation**. We demonstrate that BranchSBM is capable of modeling high-dimensional gene expression data and learning branched trajectories that accurately reconstruct diverging perturbed cell populations.
46
+
47
+ We extract the data for a single cell line (A-549) under perturbation with Clonidine and Trametinib at 5 µM, selected based on cell abundance and response diversity from the Tahoe-100M dataset.
48
+
49
+ For the Clonidine perturbation data, we show that **BranchSBM reconstructs the ground-truth distributions, capturing the location and spread of the dataset**, whereas single-branch SBM fails to differentiate cells in cluster 1 that differ from cluster 0 in higher-dimensional principal components. We also show that BranchSBM can simulate trajectories in high-dimensional state spaces by *scaling up to 150 PCs*.
50
+
51
+ ![Clonidine Experiment](assets/clonidine.png)
52
+
53
+ We further show that BranchSBM can **scale beyond two branches by modeling the perturbed cell population of Trametinib-treated cells**, which diverge into *three distinct clusters*. We trained BranchSBM with three endpoints and single-branch SBM with one endpoint containing all three clusters on the top 50 PCs.
54
+
55
+ ![Trametinib Experiment](assets/trametinib.png)
56
+
57
+
58
+ ## Citation
59
+ If you find this repository helpful for your publications, please consider citing our paper:
60
+ ```
61
+ @article{tang2026branchsbm,
62
+ title={Branched Schrödinger Bridge Matching},
63
+ author={Tang, Sophia and Zhang, Yinuo and Tong, Alexander and Chatterjee, Pranam},
64
+ journal={14th International Conference on Learning Representations (ICLR 2026)},
65
+ year={2026}
66
+ }
67
+ ```
68
+ To use this repository, you agree to abide by the MIT License.
assets/branchsbm.png ADDED

Git LFS Details

  • SHA256: a761a2158f833de59d96c2319e33b1197795a0bb6589de549316b34edd30be6c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.71 MB
assets/branchsbm_anim.gif ADDED

Git LFS Details

  • SHA256: c8664b6455486dc347242cda8a84e951ef4f5e380bd2746f3a415de82cf919fc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.49 MB
assets/clonidine.png ADDED

Git LFS Details

  • SHA256: b3b84a92120765db993d20bb39bbd56123a7d787aff95e6ee2e1799c44ebdc30
  • Pointer size: 133 Bytes
  • Size of remote file: 12.9 MB
assets/lidar.png ADDED

Git LFS Details

  • SHA256: 20160a1b105aea66d8305ceaead754e2e37c22595b711208652d7059fc76955a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.98 MB
assets/mouse.png ADDED

Git LFS Details

  • SHA256: 725f5df7f9d9d2e50200030a6d927ca0e01247361875f60fd26ea01be33c08ef
  • Pointer size: 132 Bytes
  • Size of remote file: 8.74 MB
assets/trametinib.png ADDED

Git LFS Details

  • SHA256: 18e425c06dd2c2f3bc7fa12386d24f3b9e822b0bb76cc2c377baa13d3db5e82d
  • Pointer size: 132 Bytes
  • Size of remote file: 6.05 MB
assets/veres.png ADDED

Git LFS Details

  • SHA256: 64dd688eff9c4c242bc267c617d327b801c2a03d7f91f9d30813b379da6e8fd1
  • Pointer size: 132 Bytes
  • Size of remote file: 4.64 MB
configs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
configs/clonidine_100D.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "clonidine100D"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 100
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 300
15
+ kappa: 2
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 100
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 2
22
+ metric_clusters: 3
configs/clonidine_150D.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "clonidine150D"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 150
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 300
15
+ kappa: 3
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 100
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 2
22
+ metric_clusters: 3
configs/clonidine_50D.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "clonidine50D"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 50
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 150
15
+ kappa: 1.5
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 100
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 2
22
+ metric_clusters: 3
configs/clonidine_50Dsingle.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "clonidine50Dsingle"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 50
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 150
15
+ kappa: 1.5
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 100
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 1
22
+ metric_clusters: 2
configs/lidar.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "lidar"
2
+ data_name: "lidar"
3
+ dim: 3
4
+ whiten: true
5
+ t_exclude: []
6
+ velocity_metric: "land"
7
+ gammas: [0.125]
8
+ rho: 0.001
9
+ branchsbm: true
10
+ seeds: [42]
11
+ patience_geopath: 50
12
+ metric_epochs: 100
13
+ time_geopath: true
14
+ branches: 2
15
+ metric_clusters: 3
configs/lidar_single.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "lidar"
2
+ data_name: "lidarsingle"
3
+ dim: 3
4
+ whiten: true
5
+ t_exclude: []
6
+ velocity_metric: "land"
7
+ gammas: [0.125]
8
+ rho: 0.001
9
+ branchsbm: true
10
+ seeds: [42]
11
+ patience_geopath: 50
12
+ metric_epochs: 100
13
+ time_geopath: true
14
+ branches: 1
15
+ metric_clusters: 2
configs/mouse.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "scrna"
2
+ data_name: "mouse"
3
+ hidden_dims_geopath: [64, 64, 64]
4
+ hidden_dims_flow: [64, 64, 64]
5
+ hidden_dims_growth: [64, 64, 64]
6
+ dim: 2
7
+ whiten: false
8
+ t_exclude: []
9
+ velocity_metric: "land"
10
+ gammas: [0.125]
11
+ rho: 0.001
12
+ branchsbm: true
13
+ seeds: [42]
14
+ patience_geopath: 50
15
+ metric_epochs: 100
16
+ time_geopath: false
17
+ branches: 2
18
+ metric_clusters: 2
configs/mouse_single.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "scrna"
2
+ data_name: "mousesingle"
3
+ hidden_dims_geopath: [64, 64, 64]
4
+ hidden_dims_flow: [64, 64, 64]
5
+ hidden_dims_growth: [64, 64, 64]
6
+ dim: 2
7
+ whiten: false
8
+ t_exclude: []
9
+ velocity_metric: "land"
10
+ gammas: [0.125]
11
+ rho: 0.001
12
+ branchsbm: true
13
+ seeds: [42]
14
+ patience_geopath: 50
15
+ metric_epochs: 100
16
+ time_geopath: true
17
+ branches: 1
18
+ metric_clusters: 2
configs/trametinib.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "trametinib"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 50
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 150
15
+ kappa: 1.5
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 100
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 3
22
+ metric_clusters: 4
configs/trametinib_single.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "tahoe"
2
+ data_name: "trametinibsingle"
3
+ accelerator: "gpu"
4
+ hidden_dims_geopath: [1024, 1024, 1024]
5
+ hidden_dims_flow: [1024, 1024, 1024]
6
+ hidden_dims_growth: [1024, 1024, 1024]
7
+ dim: 50
8
+ t_exclude: []
9
+ time_geopath: true
10
+ whiten: false
11
+ velocity_metric: "rbf"
12
+ metric_patience: 25
13
+ patience: 25
14
+ n_centers: 150
15
+ kappa: 1.5
16
+ rho: -2.75
17
+ alpha_metric: 1
18
+ metric_epochs: 100
19
+ branchsbm: true
20
+ seeds: [42]
21
+ branches: 1
22
+ metric_clusters: 2
configs/veres.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_type: "scrna"
2
+ data_name: "veres"
3
+ data_path: "data/Veres_alltime.csv"
4
+ accelerator: "gpu"
5
+ hidden_dims_geopath: [512, 512, 512]
6
+ hidden_dims_flow: [512, 512, 512]
7
+ hidden_dims_growth: [512, 512, 512]
8
+ dim: 30
9
+ t_exclude: []
10
+ time_geopath: true
11
+ whiten: false
12
+ velocity_metric: "rbf"
13
+ metric_patience: 25
14
+ patience: 25
15
+ patience_geopath: 50
16
+ n_centers: 300
17
+ kappa: 2
18
+ rho: 0.001
19
+ alpha_metric: 1.0
20
+ metric_epochs: 100
21
+ branchsbm: true
22
+ seeds: [42]
23
+ branches: 5
24
+ metric_clusters: 2
25
+ batch_size: 256
dataloaders/.DS_Store ADDED
Binary file (6.15 kB). View file
 
dataloaders/clonidine_single_branch.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from sklearn.preprocessing import StandardScaler
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
7
+ import pandas as pd
8
+ import numpy as np
9
+ from functools import partial
10
+ from scipy.spatial import cKDTree
11
+ from sklearn.cluster import KMeans
12
+ from torch.utils.data import TensorDataset
13
+
14
+
15
class ClonidineSingleBranchDataModule(pl.LightningDataModule):
    """DataModule for the single-branch (baseline SBM) Clonidine experiment.

    Reads PCA coordinates and Leiden cluster labels from a combined CSV and
    builds:
      * x0 — control (DMSO) cells drawn from the largest DMSO Leiden cluster,
      * x1 — the union of the two Clonidine endpoint clusters ('0.0', '4.0'),
    with all populations down-sampled to a common size by keeping the points
    closest to each cluster centroid.  Also prepares a cKDTree over the full
    dataset for manifold smoothing and two KMeans clusters that serve as
    "metric samples" for the metric learners.

    NOTE(review): code below was reconstructed from a rendered diff whose
    indentation was stripped; statement nesting in _prepare_data follows the
    logical reading of the original comments.
    """

    def __init__(self, args):
        # `args` is a parsed argument namespace (presumably from parsers.py —
        # not visible here); the fields read below are batch_size, dim,
        # whiten and split_ratios.
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios

        self.dim = args.dim
        print("dimension")
        print(self.dim)
        # Path to the combined PCA + Leiden-label table.
        self.data_path = "./data/pca_and_leiden_labels.csv"
        # Only two observed marginals: control (t=0) and perturbed (t=1).
        self.num_timesteps = 2
        self.args = args
        # Eagerly load and split the data at construction time.
        self._prepare_data()

    def _prepare_data(self):
        """Load the CSV, build the x0/x1 tensors, and create all dataloaders."""
        df = pd.read_csv(self.data_path, comment='#')
        # Drop the leading index column; the remaining leading columns are PCs.
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        # The first `dim` columns are the principal components to model.
        pc_cols = df.columns[:self.dim]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'

        dmso_mask = df[leiden_dmso_col].notna()  # Has leiden value in DMSO column
        clonidine_mask = df[leiden_clonidine_col].notna()  # Has leiden value in Clonidine column

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        # The two Clonidine Leiden clusters that form the (merged) endpoint.
        top_clonidine_clusters = ['0.0', '4.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]

        x1_1_coords = x1_1_data[pc_cols].values
        x1_2_coords = x1_2_data[pc_cols].values

        x1_1_coords = x1_1_coords.astype(float)
        x1_2_coords = x1_2_coords.astype(float)

        # Common population size: the smaller of the two endpoint clusters.
        target_size = min(len(x1_1_coords), len(x1_2_coords),)

        # Helper: deterministically keep the `target_size` points nearest the
        # cluster centroid (instead of random sampling).
        def select_closest_to_centroid(coords, target_size):
            if len(coords) <= target_size:
                return coords

            # Calculate centroid
            centroid = np.mean(coords, axis=0)

            # Calculate distances to centroid
            distances = np.linalg.norm(coords - centroid, axis=1)

            # Get indices of closest points
            closest_indices = np.argsort(distances)[:target_size]

            return coords[closest_indices]

        # Down-sample both endpoint clusters to the shared target size.
        x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
        x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # x0 comes from the largest DMSO (control) cluster.
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        # For DMSO we also use centroid-based selection, for consistency with
        # the endpoint clusters.
        if len(dmso_coords) >= target_size:
            x0_coords = select_closest_to_centroid(dmso_coords, target_size)
        else:
            # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                # Select closest to centroid from other DMSO cells
                other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
                x0_coords = np.vstack([dmso_coords, other_selected])
            else:
                # Use all available DMSO cells and reduce target size
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

                # Re-select endpoint clusters with updated target size
                x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
                x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)

        # All populations are now exactly `target_size` points each; no
        # further resampling is needed.
        self.n_samples = target_size

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
        # Single-branch setup: both endpoint clusters merged into one x1.
        x1 = torch.cat([x1_1, x1_2], dim=0)

        self.coords_t0 = x0
        self.coords_t1 = x1
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),  # t=1
        ])

        # Train/validation split on the common size.
        split_index = int(target_size * self.split_ratios[0])

        # Guarantee the validation set holds at least one full batch
        # (the val loaders below use drop_last=True).
        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size
        print('total count is:', target_size)

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]

        self.val_x0 = val_x0

        # Uniform unit mass weights for every sample (no reweighting here).
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        # Train loaders: shuffled mini-batches of (coords, weight) pairs.
        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full dataset (all rows with complete PC values), used for the
        # manifold-projection KD-tree and for the metric samples.
        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        # Test loaders deliver everything in a single batch.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples: two KMeans clusters over the full dataset.
        #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
        km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset.numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]

        # One single-batch loader per KMeans cluster (cluster 1 listed first).
        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        """Combine the train sample loaders with the metric-sample loaders."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Validation analogue of train_dataloader()."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Test analogue of train_dataloader(); sample loaders are single-batch."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
        # NOTE: `points` is unused; the returned operator is built from the
        # full-dataset KD-tree computed in _prepare_data().
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset
        This replaces the plane projection for 2D manifold regularization
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        # Shape: (batch_size, k, dim) — dim is the number of PCs used, not
        # necessarily 2 as the class name of the original code suggests.
        nearest_pts = dataset[idx]

        # Compute weighted average of neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels
        }
dataloaders/clonidine_v2_data.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from sklearn.preprocessing import StandardScaler
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
7
+ import pandas as pd
8
+ import numpy as np
9
+ from functools import partial
10
+ from scipy.spatial import cKDTree
11
+ from sklearn.cluster import KMeans
12
+ from torch.utils.data import TensorDataset
13
+
14
+
15
class ClonidineV2DataModule(pl.LightningDataModule):
    """One-source / two-branch clonidine datamodule in PCA space.

    The source population x0 comes from the largest DMSO-control Leiden
    cluster; the two branch targets x1_1 / x1_2 are the Clonidine Leiden
    clusters '0.0' and '4.0'. All three endpoint sets are balanced to a
    common size by keeping the points closest to each cluster's centroid.
    Also builds KMeans-partition loaders of the full dataset used as
    metric samples.
    """

    def __init__(self, args):
        """Store run options from ``args`` and eagerly build all loaders.

        Args:
            args: namespace providing ``batch_size``, ``dim``, ``whiten``
                and ``split_ratios`` (train fraction first).
        """
        super().__init__()
        self.save_hyperparameters()

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios

        # Number of leading PCA columns used as coordinates.
        self.dim = args.dim
        print("dimension")
        print(self.dim)
        # Path to your combined data
        self.data_path = "./data/pca_and_leiden_labels.csv"
        self.num_timesteps = 2
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Load the combined CSV and assemble train/val/test/metric loaders."""
        df = pd.read_csv(self.data_path, comment='#')
        # Drop the leading (index) column written by to_csv.
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        # The first ``dim`` columns hold the PCA coordinates.
        pc_cols = df.columns[:self.dim]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        leiden_clonidine_col = 'leiden_Clonidine (hydrochloride)_5.0uM'

        dmso_mask = df[leiden_dmso_col].notna()  # Has leiden value in DMSO column
        clonidine_mask = df[leiden_clonidine_col].notna()  # Has leiden value in Clonidine column

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        # Branch targets: Clonidine Leiden clusters '0.0' and '4.0'.
        top_clonidine_clusters = ['0.0', '4.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]

        x1_1_coords = x1_1_data[pc_cols].values
        x1_2_coords = x1_2_data[pc_cols].values

        x1_1_coords = x1_1_coords.astype(float)
        x1_2_coords = x1_2_coords.astype(float)

        # Target size is now the minimum across all three endpoint clusters
        target_size = min(len(x1_1_coords), len(x1_2_coords),)

        # Helper function to select points closest to centroid
        def select_closest_to_centroid(coords, target_size):
            """Keep the ``target_size`` rows of ``coords`` nearest its centroid."""
            if len(coords) <= target_size:
                return coords

            # Calculate centroid
            centroid = np.mean(coords, axis=0)

            # Calculate distances to centroid
            distances = np.linalg.norm(coords - centroid, axis=1)

            # Get indices of closest points
            closest_indices = np.argsort(distances)[:target_size]

            return coords[closest_indices]

        # Sample all endpoint clusters to target size using centroid-based selection
        x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
        x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # DMSO (unchanged)
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        # Random sampling from largest DMSO cluster to match target size
        # For DMSO, we'll also use centroid-based selection for consistency
        if len(dmso_coords) >= target_size:
            x0_coords = select_closest_to_centroid(dmso_coords, target_size)
        else:
            # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                # Select closest to centroid from other DMSO cells
                other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
                x0_coords = np.vstack([dmso_coords, other_selected])
            else:
                # Use all available DMSO cells and reduce target size
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

                # Re-select endpoint clusters with updated target size
                x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
                x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)

        # No need to resample since we already selected the right number
        # The endpoint clusters are already at target_size from centroid-based selection

        self.n_samples = target_size

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)

        self.coords_t0 = x0
        self.coords_t1_1 = x1_1
        self.coords_t1_2 = x1_2
        # 0/1 time labels aligned with the [t0, t1_1, t1_2] concatenation order.
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1_1)),  # t=1
            np.ones(len(self.coords_t1_2)),
        ])

        split_index = int(target_size * self.split_ratios[0])

        # Ensure the validation split holds at least one full batch.
        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size
        print('total count is:', target_size)

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]


        self.val_x0 = val_x0

        # Per-sample weights: 1.0 for the single source, 0.5 for each branch.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        # Updated train dataloaders to include x1_3
        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # KD-tree over every row with complete PCA coordinates; used by the
        # local-smoothing manifold projection below.
        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Unsupervised 3-way KMeans partition of the full dataset, served as
        # metric-sample loaders (one full-batch loader per cluster).
        km_all = KMeans(n_clusters=3, random_state=0).fit(self.dataset.numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1
        cluster_2_mask = cluster_labels == 2

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]
        cluster_2_data = samples[cluster_2_mask]

        # NOTE(review): loaders are ordered [cluster 2, cluster 0, cluster 1]
        # — presumably to align KMeans clusters with [x0, x1_1, x1_2];
        # confirm against the downstream consumer.
        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_2_data, dtype=torch.float32),
                batch_size=cluster_2_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),

            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        """Combine training endpoint loaders with the metric-sample loaders."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combine validation endpoint loaders with the metric-sample loaders."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")


    def test_dataloader(self):
        """Combine test loaders with the metric-sample loaders."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
        # ``points`` is unused; the operator is built from the full dataset.
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset
        This replaces the plane projection for 2D manifold regularization
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Compute weighted average of neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1_1': self.coords_t1_1,
            't1_2': self.coords_t1_2,
            'time_labels': self.time_labels
        }
+ }
280
+
dataloaders/lidar_data.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from sklearn.preprocessing import StandardScaler
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+ from pytorch_lightning.utilities.combined_loader import CombinedLoader
7
+ import laspy
8
+ import numpy as np
9
+ from scipy.spatial import cKDTree
10
+ import math
11
+ from functools import partial
12
+ from torch.utils.data import TensorDataset
13
+
14
+
15
class GaussianMM:
    """Isotropic Gaussian mixture with uniform component weights.

    ``mu`` is a (K, d) list of component means; ``var`` a shared scalar
    variance. Supports density evaluation (``logprob``) and sampling
    (calling the instance).
    """

    def __init__(self, mu, var):
        super().__init__()
        self.centers = torch.tensor(mu)              # (K, d) component means
        self.logstd = torch.tensor(var).log() / 2.0  # scalar log-std = log(var) / 2
        self.K = self.centers.shape[0]

    def logprob(self, x):
        """Log-density of each row of ``x`` under the mixture."""
        per_dim = self.normal_logprob(
            x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
        )
        per_component = torch.sum(per_dim, dim=2)
        # Uniform mixture: logsumexp over components minus log K.
        return torch.logsumexp(per_component, dim=1) - math.log(self.K)

    def normal_logprob(self, z, mean, log_std):
        """Elementwise univariate normal log-density of ``z``."""
        mean = mean + torch.tensor(0.0)
        log_std = log_std + torch.tensor(0.0)
        log_two_pi = torch.tensor([math.log(2 * math.pi)]).to(z)
        standardized = (z - mean) * torch.exp(-log_std)
        return -0.5 * (standardized * standardized + 2 * log_std + log_two_pi)

    def __call__(self, n_samples):
        """Draw ``n_samples`` points: pick components uniformly, add noise."""
        component = torch.randint(self.K, (n_samples,)).to(self.centers.device)
        mu = self.centers[component]
        noise = torch.randn(*mu.shape).to(mu)
        return noise * torch.exp(self.logstd) + mu
41
+
42
class BranchedLidarDataModule(pl.LightningDataModule):
    """LiDAR ground-surface datamodule with one source and two branch targets.

    Endpoint clouds are sampled from small Gaussian mixtures placed on the
    terrain and snapped onto locally-fitted tangent planes of the LiDAR
    ground point cloud. Metric-sample loaders partition the terrain into
    three geometric regions (a start strip plus the two sides of the x=y
    diagonal).
    """

    def __init__(self, args):
        """Read options from ``args`` and eagerly build all loaders.

        Args:
            args: namespace providing ``data_path``, ``batch_size``, ``dim``,
                ``whiten`` and ``split_ratios`` (train fraction first).
        """
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        # Source mixture p0: (x, y, z) centers on the terrain.
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02

        # One target mixture per branch.
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5]
        ]

        self.p1_var = 0.03
        self.k = 20            # neighbours used when fitting tangent planes
        self.n_samples = 5000  # points drawn per endpoint distribution
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios
        self._prepare_data()

    def assign_region(self):
        """Assign every dataset point to the nearest endpoint region.

        Returns a LongTensor of region ids (0: p0, 1: p1_1, 2: p1_2) chosen
        by the minimum squared distance to any of that region's mixture
        centers. Vectorized; ``argmin`` keeps the first region on ties,
        matching a strict "<" scan over regions 0..2.
        """
        all_centers = {
            0: torch.tensor(self.p0_mu),    # Region 0: p0
            1: torch.tensor(self.p1_1_mu),  # Region 1: p1_1
            2: torch.tensor(self.p1_2_mu),  # Region 2: p1_2
        }

        dataset = self.dataset.to(torch.float32)
        # (N, 3): per-region minimum squared distance over that region's centers.
        min_sq_dists = torch.stack(
            [
                ((dataset.unsqueeze(1) - centers.to(dataset).unsqueeze(0)) ** 2)
                .sum(dim=-1)
                .min(dim=1)
                .values
                for _, centers in sorted(all_centers.items())
            ],
            dim=1,
        )
        return min_sq_dists.argmin(dim=1)

    def _prepare_data(self):
        """Load LAS ground points, sample endpoints, and build all loaders."""
        las = laspy.read(self.data_path)
        # Extract only "ground" points (ASPRS classification 2).
        self.mask = las.classification == 2
        # Undo the LAS integer quantization via header scales/offsets.
        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        # Rescale to x, y in [-5, 5] and z in [0, 2].
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        # Sample the endpoint clouds and snap each onto the terrain surface.
        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # BUGFIX: this clamp previously ran *after* the train/val slicing and
        # therefore had no effect; it must run first so validation always
        # holds at least one full batch.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        # Branch endpoints.
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        self.train_dataloaders = {
            "x0": DataLoader(train_x0, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(train_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(train_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
        }
        self.val_dataloaders = {
            "x0": DataLoader(val_x0, batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(val_x1_1, batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(val_x1_2, batch_size=self.batch_size, shuffle=True, drop_last=True),
        }
        # Full-batch loaders: held-out x0 plus the entire terrain cloud.
        self.test_dataloaders = [
            DataLoader(
                self.val_x0,
                batch_size=self.val_x0.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                self.dataset,
                batch_size=self.dataset.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

        # Geometric 3-way partition of the terrain used as metric samples:
        # a "start" strip along the anti-diagonal, then the two halves of
        # the remainder split by the x=y diagonal.
        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        # Diagonal-based coordinate (rotated 45°), increasing along x=y.
        u = (x + y) / np.sqrt(2)
        # Start region (A): lowest 30% of u — tweak the threshold to resize.
        u_thresh = np.percentile(u, 30)
        mask_A = u <= u_thresh

        # Among the rest, split by the x=y diagonal.
        remaining = ~mask_A
        mask_B = remaining & (x < y)   # left of diagonal
        mask_C = remaining & (x >= y)  # right of diagonal

        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        """Combine training endpoint loaders with the metric-sample loaders."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combine validation endpoint loaders with the metric-sample loaders."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Iterate the full-batch test loaders together."""
        return CombinedLoader(self.test_dataloaders)

    def get_tangent_proj(self, points):
        """Return a projector onto the tangent planes fitted at ``points``."""
        w = self.get_tangent_plane(points)
        return partial(BranchedLidarDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        """Fit a weighted least-squares plane through each point's k nearest
        terrain neighbours (Gaussian-kernel weights of bandwidth ``temp``)."""
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        # self.dataset[idx] is already a tensor; avoid torch.tensor() on a
        # tensor (extra copy + UserWarning) and just move/cast it.
        nearest_pts = self.dataset[idx].to(points)

        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        # Fits plane with least vertical distance.
        w = BranchedLidarDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Expects points to be of shape (..., 3).
        Returns [a, b, c] such that the plane is defined as
        ax + by + c = z
        (weighted least-squares when per-point ``weights`` are given).
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        # BUGFIX: the branches were inverted — the weighted design matrix was
        # only built when weights was None (raising TypeError on D * None)
        # while supplied weights were silently ignored.
        if weights is None:
            Dtrans = D.transpose(-1, -2)
        else:
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Projects points to a plane defined by w."""
        # Normal vector to the tangent plane ax + by + c = z is (a, b, -1).
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        # Offset.
        d = w[..., 2:3]

        # Projection of x onto n.
        projn_x = ((pn + d) / nn) * n

        # Remove component in the normal direction.
        return x - projn_x
+
281
class WeightedBranchedLidarDataModule(pl.LightningDataModule):
    """Weighted variant of the branched LiDAR datamodule.

    Same geometry as :class:`BranchedLidarDataModule`, but every sample
    carries a branch weight (1.0 for the source x0, 0.5 per branch target)
    and the loaders yield ``(point, weight)`` pairs via ``TensorDataset``.
    """

    def __init__(self, args):
        """Read options from ``args`` and eagerly build all loaders.

        Args:
            args: namespace providing ``data_path``, ``batch_size``, ``dim``,
                ``whiten`` and ``split_ratios`` (train fraction first).
        """
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        # Source mixture p0: (x, y, z) centers on the terrain.
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02
        # One target mixture per branch.
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5]
        ]

        self.p1_var = 0.03
        self.k = 20            # neighbours used when fitting tangent planes
        self.n_samples = 5000  # points drawn per endpoint distribution
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios

        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Load LAS ground points, sample endpoints, and build all loaders."""
        las = laspy.read(self.data_path)
        # Extract only "ground" points (ASPRS classification 2).
        self.mask = las.classification == 2
        # Undo the LAS integer quantization via header scales/offsets.
        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        # Rescale to x, y in [-5, 5] and z in [0, 2].
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        # Sample the endpoint clouds and snap each onto the terrain surface.
        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # BUGFIX: this clamp previously ran *after* the train/val slicing and
        # therefore had no effect; it must run first so validation always
        # holds at least one full batch.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)

        self.coords_t0 = x0
        self.coords_t1_1 = x1_1
        self.coords_t1_2 = x1_2
        # 0/1 time labels aligned with the [t0, t1_1, t1_2] concatenation order.
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),   # t=0
            np.ones(len(self.coords_t1_1)),  # t=1
            np.ones(len(self.coords_t1_2)),  # t=1
        ])

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        # Branch endpoints.
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        # Per-sample weights: 1.0 for the single source, 0.5 for each branch.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full-batch test loaders (validation endpoints + whole terrain).
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Geometric 3-way partition of the terrain used as metric samples:
        # a "start" strip along the anti-diagonal, then the two halves of
        # the remainder split by the x=y diagonal.
        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        # Diagonal-based coordinate (rotated 45°), increasing along x=y.
        u = (x + y) / np.sqrt(2)
        # Start region (A): lowest 30% of u — tweak the threshold to resize.
        u_thresh = np.percentile(u, 30)
        mask_A = u <= u_thresh

        # Among the rest, split by the x=y diagonal.
        remaining = ~mask_A
        mask_B = remaining & (x < y)   # left of diagonal
        mask_C = remaining & (x >= y)  # right of diagonal

        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_B], dtype=torch.float32), batch_size=points[mask_B].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[mask_C], dtype=torch.float32), batch_size=points[mask_C].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        """Combine training endpoint loaders with the metric-sample loaders."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combine validation endpoint loaders with the metric-sample loaders."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Combine test loaders with the metric-sample loaders."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_tangent_proj(self, points):
        """Return a projector onto the tangent planes fitted at ``points``."""
        w = self.get_tangent_plane(points)
        return partial(BranchedLidarDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        """Fit a weighted least-squares plane through each point's k nearest
        terrain neighbours (Gaussian-kernel weights of bandwidth ``temp``)."""
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        # self.dataset[idx] is already a tensor; avoid torch.tensor() on a
        # tensor (extra copy + UserWarning) and just move/cast it.
        nearest_pts = self.dataset[idx].to(points)

        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        # Fits plane with least vertical distance.
        w = BranchedLidarDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Expects points to be of shape (..., 3).
        Returns [a, b, c] such that the plane is defined as
        ax + by + c = z
        (weighted least-squares when per-point ``weights`` are given).
        """
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        # BUGFIX: the branches were inverted — the weighted design matrix was
        # only built when weights was None (raising TypeError on D * None)
        # while supplied weights were silently ignored.
        if weights is None:
            Dtrans = D.transpose(-1, -2)
        else:
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Projects points to a plane defined by w."""
        # Normal vector to the tangent plane ax + by + c = z is (a, b, -1).
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        # Offset.
        d = w[..., 2:3]

        # Projection of x onto n.
        projn_x = ((pn + d) / nn) * n

        # Remove component in the normal direction.
        return x - projn_x

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1_1': self.coords_t1_1,
            't1_2': self.coords_t1_2,
            'time_labels': self.time_labels
        }
+ }
525
+
526
def get_datamodule():
    """Construct and set up the weighted branched LiDAR datamodule.

    NOTE(review): this relies on a module-level ``args`` existing at call
    time (it is not defined in this file), so calling it without one raises
    NameError — confirm the intended caller injects ``args`` into this
    module's globals, or thread it through as a parameter.
    """
    datamodule = WeightedBranchedLidarDataModule(args)
    datamodule.setup(stage="fit")
    return datamodule
dataloaders/lidar_data_single.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from sklearn.preprocessing import StandardScaler
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+ from pytorch_lightning.utilities.combined_loader import CombinedLoader
7
+ import laspy
8
+ import numpy as np
9
+ from scipy.spatial import cKDTree
10
+ import math
11
+ from functools import partial
12
+ from torch.utils.data import TensorDataset
13
+
14
+
15
class GaussianMM:
    """Isotropic Gaussian mixture model with uniform component weights.

    ``mu`` is a (K, D) list of component means; ``var`` is a shared scalar
    variance applied to every component and dimension.
    """

    def __init__(self, mu, var):
        super().__init__()
        self.centers = torch.tensor(mu)               # (K, D) component means
        self.logstd = torch.tensor(var).log() / 2.0   # shared log std-dev
        self.K = self.centers.shape[0]

    def logprob(self, x):
        """Mixture log-density for each row of ``x``."""
        per_dim = self.normal_logprob(
            x.unsqueeze(1), self.centers.unsqueeze(0), self.logstd
        )
        per_component = torch.sum(per_dim, dim=2)
        # Uniform mixture weights: logsumexp over components, minus log K.
        return torch.logsumexp(per_component, dim=1) - math.log(self.K)

    def normal_logprob(self, z, mean, log_std):
        """Elementwise diagonal-Gaussian log-density of ``z``."""
        mean = mean + torch.tensor(0.0)
        log_std = log_std + torch.tensor(0.0)
        log_two_pi = torch.tensor([math.log(2 * math.pi)]).to(z)
        scaled = (z - mean) * torch.exp(-log_std)
        return -0.5 * (scaled * scaled + 2 * log_std + log_two_pi)

    def __call__(self, n_samples):
        """Draw ``n_samples`` points from the mixture."""
        idx = torch.randint(self.K, (n_samples,)).to(self.centers.device)
        mean = self.centers[idx]
        noise = torch.randn(*mean.shape).to(mean)
        return noise * torch.exp(self.logstd) + mean
41
+
42
class LidarSingleDataModule(pl.LightningDataModule):
    """LiDAR ground-surface datamodule with a single (merged) endpoint.

    Samples a Gaussian-mixture source distribution (t=0) and two endpoint
    mixtures (t=1), projects every sample onto the local tangent plane of
    the LiDAR ground surface, merges the two endpoint clouds into one
    target population, and builds train/val/test/metric dataloaders.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        # Component means of the t=0 source mixture (on the ground surface).
        self.p0_mu = [
            [-4.5, -4.0, 0.5],
            [-4.2, -3.5, 0.5],
            [-4.0, -3.0, 0.5],
            [-3.75, -2.5, 0.5],
        ]
        self.p0_var = 0.02
        # Two endpoint mixtures; their samples are merged into one target set.
        self.p1_1_mu = [
            [-2.5, -0.25, 0.5],
            [-2.25, 0.675, 0.5],
            [-2, 1.5, 0.5],
        ]
        self.p1_2_mu = [
            [2, -2, 0.5],
            [2.6, -1.25, 0.5],
            [3.2, -0.5, 0.5]
        ]

        self.p1_var = 0.03
        self.k = 20           # neighbors used for local tangent-plane fits
        self.n_samples = 5000
        self.num_timesteps = 2
        self.split_ratios = args.split_ratios

        self.metric_clusters = 3
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Load LAS ground points, sample/project endpoints, build loaders."""
        las = laspy.read(self.data_path)
        # Extract only "ground" points (ASPRS classification code 2).
        self.mask = las.classification == 2
        # Undo the integer LAS encoding: stored_value * scale + offset.
        x_offset, x_scale = las.header.offsets[0], las.header.scales[0]
        y_offset, y_scale = las.header.offsets[1], las.header.scales[1]
        z_offset, z_scale = las.header.offsets[2], las.header.scales[2]
        dataset = np.vstack(
            (
                las.X[self.mask] * x_scale + x_offset,
                las.Y[self.mask] * y_scale + y_offset,
                las.Z[self.mask] * z_scale + z_offset,
            )
        ).transpose()
        # Normalize the cloud into [-5, 5] x [-5, 5] x [0, 2].
        mi = dataset.min(axis=0, keepdims=True)
        ma = dataset.max(axis=0, keepdims=True)
        dataset = (dataset - mi) / (ma - mi) * [10.0, 10.0, 2.0] + [-5.0, -5.0, 0.0]

        self.dataset = torch.tensor(dataset, dtype=torch.float32)
        self.tree = cKDTree(dataset)

        # Sample the mixtures and snap each sample onto its local
        # tangent plane of the ground surface.
        x0_gaussian = GaussianMM(self.p0_mu, self.p0_var)(self.n_samples)
        x1_1_gaussian = GaussianMM(self.p1_1_mu, self.p1_var)(self.n_samples)
        x1_2_gaussian = GaussianMM(self.p1_2_mu, self.p1_var)(self.n_samples)

        x0 = self.get_tangent_proj(x0_gaussian)(x0_gaussian)
        x1_1 = self.get_tangent_proj(x1_1_gaussian)(x1_1_gaussian)
        x1_2 = self.get_tangent_proj(x1_2_gaussian)(x1_2_gaussian)

        split_index = int(self.n_samples * self.split_ratios[0])
        # BUG FIX: guarantee the validation split holds at least one full
        # batch *before* slicing. The original adjusted split_index only
        # after the train/val splits were already taken, so the adjustment
        # had no effect and small validation splits were silently dropped
        # by drop_last=True.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.scaler = StandardScaler()
        if self.whiten:
            self.dataset = torch.tensor(
                self.scaler.fit_transform(dataset), dtype=torch.float32
            )
            x0 = torch.tensor(self.scaler.transform(x0), dtype=torch.float32)
            x1_1 = torch.tensor(self.scaler.transform(x1_1), dtype=torch.float32)
            x1_2 = torch.tensor(self.scaler.transform(x1_2), dtype=torch.float32)
        # Single-branch setting: both endpoint clouds form one target set.
        x1 = torch.cat([x1_1, x1_2], dim=0)

        self.coords_t0 = x0
        self.coords_t1 = x1
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),   # t=1
        ])

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]

        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]

        self.val_x0 = val_x0

        # Uniform unit weights: with a single branch there is no reweighting.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Test loaders iterate each split in one full-size batch.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=False),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.val_x0.shape[0], shuffle=True, drop_last=True),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=True, drop_last=False),
        }

        # Metric samples: split the full point cloud along the x=y diagonal.
        points = self.dataset.cpu().numpy()
        x, y = points[:, 0], points[:, 1]
        # Diagonal-based coordinate (rotated 45 degrees), i.e. along x=y.
        u = (x + y) / np.sqrt(2)
        # Start region (A): the lower 30th percentile along the diagonal;
        # tweak the threshold to control its size.
        u_thresh = np.percentile(u, 30)
        mask_A = u <= u_thresh
        remaining = ~mask_A

        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(points[mask_A], dtype=torch.float32), batch_size=points[mask_A].shape[0], shuffle=False),
            DataLoader(torch.tensor(points[remaining], dtype=torch.float32), batch_size=points[remaining].shape[0], shuffle=False),
        ]

    def train_dataloader(self):
        """Combined train loader: data batches plus metric-sample batches."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combined validation loader mirroring ``train_dataloader``."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Combined test loader with full-batch splits."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_tangent_proj(self, points):
        """Return a callable projecting points onto their local tangent plane."""
        w = self.get_tangent_plane(points)
        return partial(LidarSingleDataModule.projection_op, w=w)

    def get_tangent_plane(self, points, temp=1e-3):
        """Fit a weighted tangent plane per point from its k nearest neighbors."""
        points_np = points.detach().cpu().numpy()
        _, idx = self.tree.query(points_np, k=self.k)
        # Gather neighbors and move them to the query points' device/dtype
        # (avoids the torch.tensor(tensor) copy-warning of the original).
        nearest_pts = self.dataset[idx].to(points)

        # Gaussian kernel weights: closer neighbors dominate the fit.
        dists = (points.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)

        # Fits plane with least vertical distance.
        w = LidarSingleDataModule.fit_plane(nearest_pts, weights)
        return w

    @staticmethod
    def fit_plane(points, weights=None):
        """Least-squares fit of a plane z = a*x + b*y + c to ``points``.

        ``points`` has shape (..., k, 3); optional ``weights`` has shape
        (..., k, 1). Returns coefficients ``[a, b, c]`` of shape (..., 3).
        """
        # Design matrix D = [x, y, 1]; solve (D^T W D) w = D^T W z.
        D = torch.cat([points[..., :2], torch.ones_like(points[..., 2:3])], dim=-1)
        z = points[..., 2]
        # BUG FIX: the branches were swapped — the unweighted path must use
        # plain D^T, and only the weighted path scales D by the weights
        # (the original evaluated `D * None` when weights was omitted).
        if weights is None:
            Dtrans = D.transpose(-1, -2)
        else:
            DW = D * weights
            Dtrans = DW.transpose(-1, -2)
        w = torch.linalg.solve(
            torch.matmul(Dtrans, D), torch.matmul(Dtrans, z.unsqueeze(-1))
        ).squeeze(-1)
        return w

    @staticmethod
    def projection_op(x, w):
        """Project points ``x`` onto the plane a*x + b*y + c = z given by ``w``."""
        # Normal vector to the tangent plane: [a, b, -1].
        # dim=-1 handles both 1-D and batched `w` (original used dim=1).
        n = torch.cat([w[..., :2], -torch.ones_like(w[..., 2:3])], dim=-1)

        pn = torch.sum(x * n, dim=-1, keepdim=True)
        nn = torch.sum(n * n, dim=-1, keepdim=True)

        # Offset term c of the plane equation.
        d = w[..., 2:3]

        # Out-of-plane component of x along the normal.
        projn_x = ((pn + d) / nn) * n

        # Remove the component in the normal direction.
        return x - projn_x

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization."""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels
        }
dataloaders/mouse_data.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from sklearn.preprocessing import StandardScaler
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
7
+ import numpy as np
8
+ from scipy.spatial import cKDTree
9
+ import math
10
+ from functools import partial
11
+ from sklearn.cluster import KMeans, DBSCAN
12
+ import matplotlib.pyplot as plt
13
+ import pandas as pd
14
+ from torch.utils.data import TensorDataset
15
+
16
class WeightedBranchedCellDataModule(pl.LightningDataModule):
    """Two-branch mouse-cell datamodule with per-branch sample weights.

    Reads a CSV of 2-D cell coordinates with a ``samples`` column marking
    timepoints, clusters the final timepoint (t=2) into two branches with
    KMeans, and builds train/val/test/metric dataloaders where each branch
    endpoint carries weight 0.5 and the source (t=0) carries weight 1.0.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path        # CSV with x1, x2, samples columns
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20                            # kNN size for local smoothing
        self.n_samples = 1429                  # placeholder; overwritten from data
        # NOTE(review): value is 2 but the data spans t=0, t=1, t=2 — confirm
        # whether downstream code expects the number of transitions here.
        self.num_timesteps = 2  # t=0, t=1, t=2
        self.split_ratios = args.split_ratios
        self.metric_clusters = args.metric_clusters
        self.args = args
        self._prepare_data()


    def _prepare_data(self):
        """Load the CSV, cluster t=2 into two branches, and build loaders."""
        print("Preparing cell data in BranchedCellDataModule")

        df = pd.read_csv(self.data_path)

        # Build dictionary of coordinates by time
        coords_by_t = {
            t: df[df["samples"] == t][["x1","x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0] # Number of T=0 points
        self.n_samples = n0 # Update n_samples to match actual data if changes

        # Cluster the t=2 cells into two branches
        km = KMeans(n_clusters=2, random_state=42).fit(coords_by_t[2])
        df2 = df[df["samples"] == 2].copy()
        df2["branch"] = km.labels_

        cluster_counts = df2["branch"].value_counts().sort_index()
        print(cluster_counts)

        # Sample n0 points from each branch so endpoints match the source size.
        endpoints = {}
        for b in (0, 1):
            endpoints[b] = (
                df2[df2["branch"] == b]
                .sample(n=n0, random_state=42)[["x1","x2"]]
                .values
            )

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) # T=0 coordinates index
        x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)
        x1_1 = torch.tensor(endpoints[0], dtype=torch.float32) # Branch index
        x1_2 = torch.tensor(endpoints[1], dtype=torch.float32) # Branch index

        # Keep raw per-timepoint tensors for get_timepoint_data()/plotting.
        self.coords_t0 = x0
        self.coords_t1 = x_inter
        self.coords_t2_1 = x1_1
        self.coords_t2_2 = x1_2
        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),          # t=0
            np.ones(len(self.coords_t1)),           # t=1
            np.ones(len(self.coords_t2_1)) * 2,     # t=2, branch 1
            np.ones(len(self.coords_t2_2)) * 2,     # t=2, branch 2
        ])

        split_index = int(n0 * self.split_ratios[0])

        # Ensure the validation split holds at least one full batch.
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1_1 = x1_1[:split_index]
        val_x1_1 = x1_1[split_index:]
        train_x1_2 = x1_2[:split_index]
        val_x1_2 = x1_2[split_index:]

        self.val_x0 = val_x0

        # Source samples carry weight 1.0; each of the two branch endpoints
        # carries 0.5 so the branches together weigh as much as the source.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.5)
        train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.5)
        val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.5)

        # NOTE(review): this second adjustment runs after the splits were
        # already sliced above, so reassigning split_index here has no
        # effect — likely dead code left over from a refactor.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full dataset (all timepoints stacked) for tests and kNN smoothing.
        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        # if whitening is enabled, need to apply this to the full dataset
        #if self.whiten:
            #self.scaler = StandardScaler()
            #self.dataset = torch.tensor(
                #self.scaler.fit_transform(all_data), dtype=torch.float32
            #)

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric Dataloader
        # K-means clustering of ALL points into `metric_clusters` groups;
        # each cluster becomes one full-batch metric loader.
        if self.metric_clusters == 3:
            km_all = KMeans(n_clusters=3, random_state=45).fit(self.dataset.numpy())
            cluster_labels = km_all.labels_

            cluster_0_mask = cluster_labels == 0
            cluster_1_mask = cluster_labels == 1
            cluster_2_mask = cluster_labels == 2

            samples = self.dataset.cpu().numpy()

            cluster_0_data = samples[cluster_0_mask]
            cluster_1_data = samples[cluster_1_mask]
            cluster_2_data = samples[cluster_2_mask]

            self.metric_samples_dataloaders = [
                DataLoader(
                    torch.tensor(cluster_1_data, dtype=torch.float32),
                    batch_size=cluster_1_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
                DataLoader(
                    torch.tensor(cluster_2_data, dtype=torch.float32),
                    batch_size=cluster_2_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),

                DataLoader(
                    torch.tensor(cluster_0_data, dtype=torch.float32),
                    batch_size=cluster_0_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
            ]
        else:
            km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
            cluster_labels = km_all.labels_

            cluster_0_mask = cluster_labels == 0
            cluster_1_mask = cluster_labels == 1

            samples = self.dataset.cpu().numpy()

            cluster_0_data = samples[cluster_0_mask]
            cluster_1_data = samples[cluster_1_mask]

            self.metric_samples_dataloaders = [
                DataLoader(
                    torch.tensor(cluster_1_data, dtype=torch.float32),
                    batch_size=cluster_1_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
                DataLoader(
                    torch.tensor(cluster_0_data, dtype=torch.float32),
                    batch_size=cluster_0_data.shape[0],
                    shuffle=False,
                    drop_last=False,
                ),
            ]


    def train_dataloader(self):
        """Combined train loader: data batches plus metric-sample batches."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combined validation loader mirroring ``train_dataloader``."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Combined test loader with full-batch splits."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset
        This replaces the plane projection for 2D manifold regularization
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Compute weighted average of neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            't2_1': self.coords_t2_1,
            't2_2': self.coords_t2_2,
            'time_labels': self.time_labels
        }
262
+
263
+
264
+
265
class SingleBranchCellDataModule(pl.LightningDataModule):
    """Single-branch variant of the mouse-cell datamodule.

    Reads the same CSV of 2-D cell coordinates with a ``samples`` timepoint
    column, but treats all t=2 cells as one endpoint population (no branch
    clustering) and builds train/val/test/metric dataloaders.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters()

        self.data_path = args.data_path        # CSV with x1, x2, samples columns
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20                            # kNN size for local smoothing
        self.n_samples = 1429                  # placeholder; overwritten from data
        self.num_timesteps = 3  # t=0, t=1, t=2
        self.split_ratios = args.split_ratios
        self.metric_clusters = 3
        self.args = args
        self._prepare_data()


    def _prepare_data(self):
        """Load the CSV, slice per-timepoint coordinates, and build loaders."""
        print("Preparing cell data in BranchedCellDataModule")

        df = pd.read_csv(self.data_path)

        # Build dictionary of coordinates by time
        coords_by_t = {
            t: df[df["samples"] == t][["x1","x2"]].values
            for t in sorted(df["samples"].unique())
        }
        n0 = coords_by_t[0].shape[0] # Number of T=0 points
        self.n_samples = n0 # Update n_samples to match actual data if changes

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32) # T=0 coordinates index
        x_inter = torch.tensor(coords_by_t[1], dtype=torch.float32)
        x1 = torch.tensor(coords_by_t[2], dtype=torch.float32) # Branch index

        # Store for get_timepoint_data()
        self.coords_t0 = x0
        self.coords_t1 = x_inter
        self.coords_t2 = x1
        self.time_labels = np.concatenate([
            np.zeros(len(x0)),
            np.ones(len(x_inter)),
            np.ones(len(x1)) * 2,
        ])

        split_index = int(n0 * self.split_ratios[0])

        # Ensure the validation split holds at least one full batch.
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        train_x1 = x1[:split_index]
        val_x1 = x1[split_index:]

        self.val_x0 = val_x0

        # NOTE(review): the endpoint weight is 0.5 here even though there is
        # only one branch — presumably kept for parity with the two-branch
        # module; confirm downstream weighting expectations.
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=0.5)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=0.5)

        # NOTE(review): this second adjustment runs after the splits were
        # already sliced above, so reassigning split_index here has no
        # effect — likely dead code left over from a refactor.
        if self.n_samples - split_index < self.batch_size:
            split_index = self.n_samples - self.batch_size

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full dataset (all timepoints stacked) for tests and kNN smoothing.
        all_data = np.vstack([coords_by_t[t] for t in sorted(coords_by_t.keys())])
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        # if whitening is enabled, need to apply this to the full dataset
        if self.whiten:
            self.scaler = StandardScaler()
            self.dataset = torch.tensor(
                self.scaler.fit_transform(all_data), dtype=torch.float32
            )

        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric Dataloader
        # K-means clustering of ALL points into 2 groups; each cluster
        # becomes one full-batch metric loader.
        km_all = KMeans(n_clusters=2, random_state=45).fit(self.dataset.numpy())
        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]

        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]


    def train_dataloader(self):
        """Combined train loader: data batches plus metric-sample batches."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Combined validation loader mirroring ``train_dataloader``."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Combined test loader with full-batch splits."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Adapted for 2D cell data - uses local neighborhood averaging instead of plane fitting"""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Apply local smoothing based on k-nearest neighbors in the full dataset
        This replaces the plane projection for 2D manifold regularization
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # Shape: (batch_size, k, 2)

        # Compute weighted average of neighbors
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted average of neighbors
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend original point with smoothed version
        alpha = 0.3  # How much smoothing to apply
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return data organized by timepoints for visualization"""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            't2': self.coords_t2,
            'time_labels': self.time_labels
        }
449
+
450
+ """def get_datamodule():
451
+ datamodule = WeightedBranchedCellDataModule(args)
452
+ datamodule.setup(stage="fit")
453
+ return datamodule"""
dataloaders/three_branch_data.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from sklearn.preprocessing import StandardScaler
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
7
+ import pandas as pd
8
+ import numpy as np
9
+ from functools import partial
10
+ from scipy.spatial import cKDTree
11
+ from sklearn.cluster import KMeans
12
+ from torch.utils.data import TensorDataset
13
+
14
+ class ThreeBranchTahoeDataModule(pl.LightningDataModule):
15
+ def __init__(self, args):
16
+ super().__init__()
17
+ self.save_hyperparameters()
18
+
19
+ self.batch_size = args.batch_size
20
+ self.max_dim = args.dim
21
+ self.whiten = args.whiten
22
+ self.split_ratios = args.split_ratios
23
+ self.num_timesteps = 2
24
+ self.data_path = f"{args.working_dir}/data/Trametinib_5.0uM_pca_and_leidenumap_labels.csv"
25
+ self.args = args
26
+
27
+ self._prepare_data()
28
+
29
+ def _prepare_data(self):
30
+ df = pd.read_csv(self.data_path, comment='#')
31
+ df = df.iloc[:, 1:]
32
+ df = df.replace('', np.nan)
33
+ pc_cols = df.columns[:50]
34
+ for col in pc_cols:
35
+ df[col] = pd.to_numeric(df[col], errors='coerce')
36
+ leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
37
+ leiden_clonidine_col = 'leiden_Trametinib_5.0uM'
38
+
39
+ dmso_mask = df[leiden_dmso_col].notna() # Has leiden value in DMSO column
40
+ clonidine_mask = df[leiden_clonidine_col].notna() # Has leiden value in Clonidine column
41
+
42
+ dmso_data = df[dmso_mask].copy()
43
+ clonidine_data = df[clonidine_mask].copy()
44
+
45
+ # Updated to include all three clusters: 0, 4, and 6
46
+ top_clonidine_clusters = ['1.0', '3.0', '5.0']
47
+
48
+ x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
49
+ x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
50
+ x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]]
51
+
52
+ x1_1_coords = x1_1_data[pc_cols].values
53
+ x1_2_coords = x1_2_data[pc_cols].values
54
+ x1_3_coords = x1_3_data[pc_cols].values
55
+
56
+ x1_1_coords = x1_1_coords.astype(float)
57
+ x1_2_coords = x1_2_coords.astype(float)
58
+ x1_3_coords = x1_3_coords.astype(float)
59
+
60
+ # Target size is now the minimum across all three endpoint clusters
61
+ target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords))
62
+
63
+ # Helper function to select points closest to centroid
64
+ def select_closest_to_centroid(coords, target_size):
65
+ if len(coords) <= target_size:
66
+ return coords
67
+
68
+ # Calculate centroid
69
+ centroid = np.mean(coords, axis=0)
70
+
71
+ # Calculate distances to centroid
72
+ distances = np.linalg.norm(coords - centroid, axis=1)
73
+
74
+ # Get indices of closest points
75
+ closest_indices = np.argsort(distances)[:target_size]
76
+
77
+ return coords[closest_indices]
78
+
79
+ # Sample all endpoint clusters to target size using centroid-based selection
80
+ x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
81
+ x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
82
+ x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size)
83
+
84
+ dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()
85
+
86
+ # DMSO (unchanged)
87
+ largest_dmso_cluster = dmso_cluster_counts.index[0]
88
+ dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]
89
+
90
+ dmso_coords = dmso_cluster_data[pc_cols].values
91
+
92
+ # Random sampling from largest DMSO cluster to match target size
93
+ # For DMSO, we'll also use centroid-based selection for consistency
94
+ if len(dmso_coords) >= target_size:
95
+ x0_coords = select_closest_to_centroid(dmso_coords, target_size)
96
+ else:
97
+ # If largest cluster is smaller than target, use all of it and pad with other DMSO cells
98
+ remaining_needed = target_size - len(dmso_coords)
99
+ other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
100
+ other_dmso_coords = other_dmso_data[pc_cols].values
101
+
102
+ if len(other_dmso_coords) >= remaining_needed:
103
+ # Select closest to centroid from other DMSO cells
104
+ other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
105
+ x0_coords = np.vstack([dmso_coords, other_selected])
106
+ else:
107
+ # Use all available DMSO cells and reduce target size
108
+ all_dmso_coords = dmso_data[pc_cols].values
109
+ target_size = min(target_size, len(all_dmso_coords))
110
+ x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)
111
+
112
+ # Re-select endpoint clusters with updated target size
113
+ x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
114
+ x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
115
+ x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size)
116
+
117
+
118
+ self.n_samples = target_size
119
+
120
+ # for plotting
121
+ self.coords_t0 = torch.tensor(x0_coords, dtype=torch.float32)
122
+ self.coords_t1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
123
+ self.coords_t1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
124
+ self.coords_t1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)
125
+
126
+ self.time_labels = np.concatenate([
127
+ np.zeros(len(self.coords_t0)), # t=0
128
+ np.ones(len(self.coords_t1_1)), # t=1
129
+ np.ones(len(self.coords_t1_2)), # t=1
130
+ np.ones(len(self.coords_t1_3)), # t=1
131
+ ])
132
+
133
+ x0 = torch.tensor(x0_coords, dtype=torch.float32)
134
+ x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
135
+ x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
136
+ x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)
137
+
138
+ split_index = int(target_size * self.split_ratios[0])
139
+
140
+ if target_size - split_index < self.batch_size:
141
+ split_index = target_size - self.batch_size
142
+
143
+ train_x0 = x0[:split_index]
144
+ val_x0 = x0[split_index:]
145
+ train_x1_1 = x1_1[:split_index]
146
+ val_x1_1 = x1_1[split_index:]
147
+ train_x1_2 = x1_2[:split_index]
148
+ val_x1_2 = x1_2[split_index:]
149
+ train_x1_3 = x1_3[:split_index]
150
+ val_x1_3 = x1_3[split_index:]
151
+
152
+ self.val_x0 = val_x0
153
+
154
+ train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
155
+ train_x1_1_weights = torch.full((train_x1_1.shape[0], 1), fill_value=0.603)
156
+ train_x1_2_weights = torch.full((train_x1_2.shape[0], 1), fill_value=0.255)
157
+ train_x1_3_weights = torch.full((train_x1_3.shape[0], 1), fill_value=0.142)
158
+
159
+ val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
160
+ val_x1_1_weights = torch.full((val_x1_1.shape[0], 1), fill_value=0.603)
161
+ val_x1_2_weights = torch.full((val_x1_2.shape[0], 1), fill_value=0.255)
162
+ val_x1_3_weights = torch.full((val_x1_3.shape[0], 1), fill_value=0.142)
163
+
164
+ # Updated train dataloaders to include x1_3
165
+ self.train_dataloaders = {
166
+ "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
167
+ "x1_1": DataLoader(TensorDataset(train_x1_1, train_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
168
+ "x1_2": DataLoader(TensorDataset(train_x1_2, train_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
169
+ "x1_3": DataLoader(TensorDataset(train_x1_3, train_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
170
+ }
171
+
172
+ # Updated val dataloaders to include x1_3
173
+ self.val_dataloaders = {
174
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
175
+ "x1_1": DataLoader(TensorDataset(val_x1_1, val_x1_1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
176
+ "x1_2": DataLoader(TensorDataset(val_x1_2, val_x1_2_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
177
+ "x1_3": DataLoader(TensorDataset(val_x1_3, val_x1_3_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
178
+ }
179
+
180
+ all_coords = df[pc_cols].dropna().values.astype(float)
181
+ self.dataset = torch.tensor(all_coords, dtype=torch.float32)
182
+ self.tree = cKDTree(all_coords)
183
+
184
+ self.test_dataloaders = {
185
+ "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
186
+ "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
187
+ }
188
+
189
+ # Updated metric samples - now using 4 clusters instead of 3
190
+ #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
191
+ km_all = KMeans(n_clusters=4, random_state=0).fit(self.dataset[:, :3].numpy())
192
+
193
+ cluster_labels = km_all.labels_
194
+
195
+ cluster_0_mask = cluster_labels == 0
196
+ cluster_1_mask = cluster_labels == 1
197
+ cluster_2_mask = cluster_labels == 2
198
+ cluster_3_mask = cluster_labels == 3
199
+
200
+ samples = self.dataset.cpu().numpy()
201
+
202
+ cluster_0_data = samples[cluster_0_mask]
203
+ cluster_1_data = samples[cluster_1_mask]
204
+ cluster_2_data = samples[cluster_2_mask]
205
+ cluster_3_data = samples[cluster_3_mask]
206
+
207
+ self.metric_samples_dataloaders = [
208
+ DataLoader(
209
+ torch.tensor(cluster_1_data, dtype=torch.float32),
210
+ batch_size=cluster_1_data.shape[0],
211
+ shuffle=False,
212
+ drop_last=False,
213
+ ),
214
+ DataLoader(
215
+ torch.tensor(cluster_3_data, dtype=torch.float32),
216
+ batch_size=cluster_3_data.shape[0],
217
+ shuffle=False,
218
+ drop_last=False,
219
+ ),
220
+
221
+
222
+ DataLoader(
223
+ torch.tensor(cluster_2_data, dtype=torch.float32),
224
+ batch_size=cluster_2_data.shape[0],
225
+ shuffle=False,
226
+ drop_last=False,
227
+ ),
228
+ DataLoader(
229
+ torch.tensor(cluster_0_data, dtype=torch.float32),
230
+ batch_size=cluster_0_data.shape[0],
231
+ shuffle=False,
232
+ drop_last=False,
233
+ ),
234
+ ]
235
+
236
def train_dataloader(self):
    """Yield batches pairing the per-population training loaders with the
    metric-sample loaders.

    Each inner group is merged with mode="min_size" (iteration stops at the
    shortest member); the two groups are then cycled together with
    mode="max_size_cycle".
    """
    return CombinedLoader(
        {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        },
        mode="max_size_cycle",
    )
244
+
245
def val_dataloader(self):
    """Validation counterpart of train_dataloader(): pairs the validation
    loaders with the same metric-sample loaders."""
    return CombinedLoader(
        {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        },
        mode="max_size_cycle",
    )
254
+
255
def test_dataloader(self):
    """Test counterpart: pairs the full-batch test loaders with the
    metric-sample loaders."""
    return CombinedLoader(
        {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        },
        mode="max_size_cycle",
    )
264
+
265
def get_manifold_proj(self, points):
    """Build a manifold-projection operator based on k-NN smoothing over the
    full point cloud (``points`` is accepted for interface compatibility but
    is not used)."""
    proj = partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)
    return proj
268
+
269
+ @staticmethod
270
+ def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
271
+ """
272
+ Apply local smoothing based on k-nearest neighbors in the full dataset
273
+ This replaces the plane projection for 2D manifold regularization
274
+ """
275
+ points_np = x.detach().cpu().numpy()
276
+ _, idx = tree.query(points_np, k=k)
277
+ nearest_pts = dataset[idx] # Shape: (batch_size, k, 2)
278
+
279
+ # Compute weighted average of neighbors
280
+ dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
281
+ weights = torch.exp(-dists / temp)
282
+ weights = weights / weights.sum(dim=1, keepdim=True)
283
+
284
+ # Weighted average of neighbors
285
+ smoothed = (weights * nearest_pts).sum(dim=1)
286
+
287
+ # Blend original point with smoothed version
288
+ alpha = 0.3 # How much smoothing to apply
289
+ return (1 - alpha) * x + alpha * smoothed
290
+
291
def get_timepoint_data(self):
    """Bundle the stored coordinates by timepoint (plus per-point time
    labels) for the visualization code."""
    ordered = (
        ('t0', self.coords_t0),
        ('t1_1', self.coords_t1_1),
        ('t1_2', self.coords_t1_2),
        ('t1_3', self.coords_t1_3),
        ('time_labels', self.time_labels),
    )
    return dict(ordered)
300
+
301
def get_datamodule():
    """Parse CLI arguments, construct the ThreeBranchTahoe data module and
    run its fit-stage setup."""
    from plot.parsers_tahoe import parse_args
    dm = ThreeBranchTahoeDataModule(parse_args())
    dm.setup(stage="fit")
    return dm
dataloaders/trametinib_single.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from sklearn.preprocessing import StandardScaler
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
7
+ import pandas as pd
8
+ import numpy as np
9
+ from functools import partial
10
+ from scipy.spatial import cKDTree
11
+ from sklearn.cluster import KMeans
12
+ from torch.utils.data import TensorDataset
13
+
14
+
15
class TrametinibSingleBranchDataModule(pl.LightningDataModule):
    """Data module for the single-branch Trametinib (5.0 uM) task.

    Reads a CSV of PCA coordinates with per-condition Leiden cluster labels,
    builds a DMSO control population as the source (t=0) and the *first* of
    three Trametinib endpoint clusters as the single target (t=1), and
    exposes train/val/test dataloaders plus "metric sample" loaders used for
    manifold-metric learning.
    """

    def __init__(self, args):
        """Store config values and eagerly build all dataloaders.

        Args:
            args: parsed argument namespace providing ``batch_size``, ``dim``,
                ``whiten``, ``split_ratios`` and ``data_path``.
        """
        super().__init__()
        self.save_hyperparameters()  # record init args in checkpoints/loggers

        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.split_ratios = args.split_ratios
        self.num_timesteps = 2  # one transition: control (t=0) -> treated (t=1)
        self.data_path = args.data_path
        self.args = args

        self._prepare_data()

    def _prepare_data(self):
        """Load the CSV, subsample source/target populations, build loaders."""
        # '#'-prefixed lines are metadata; the first column is a row index.
        df = pd.read_csv(self.data_path, comment='#')
        df = df.iloc[:, 1:]
        df = df.replace('', np.nan)
        # The first 50 columns are taken to be the PCA coordinates --
        # TODO confirm against the CSV schema.
        pc_cols = df.columns[:50]
        for col in pc_cols:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        # NOTE(review): variable names below still say "clonidine" because
        # this loader was adapted from the clonidine version; the column
        # actually used here is the Trametinib one.
        leiden_dmso_col = 'leiden_DMSO_TF_0.0uM'
        leiden_clonidine_col = 'leiden_Trametinib_5.0uM'

        dmso_mask = df[leiden_dmso_col].notna()  # rows with a DMSO Leiden label
        clonidine_mask = df[leiden_clonidine_col].notna()  # rows with a Trametinib Leiden label

        dmso_data = df[dmso_mask].copy()
        clonidine_data = df[clonidine_mask].copy()

        # The three Trametinib endpoint Leiden clusters treated as branches.
        top_clonidine_clusters = ['1.0', '3.0', '5.0']

        x1_1_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[0]]
        x1_2_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[1]]
        x1_3_data = clonidine_data[clonidine_data[leiden_clonidine_col].astype(str) == top_clonidine_clusters[2]]

        x1_1_coords = x1_1_data[pc_cols].values
        x1_2_coords = x1_2_data[pc_cols].values
        x1_3_coords = x1_3_data[pc_cols].values

        x1_1_coords = x1_1_coords.astype(float)
        x1_2_coords = x1_2_coords.astype(float)
        x1_3_coords = x1_3_coords.astype(float)

        # Common sample count: the smallest of the three endpoint clusters.
        target_size = min(len(x1_1_coords), len(x1_2_coords), len(x1_3_coords))

        # Deterministically downsample a cluster by keeping the points closest
        # to its centroid (keeps the dense core, drops outliers).
        def select_closest_to_centroid(coords, target_size):
            if len(coords) <= target_size:
                return coords

            # Cluster centroid in PC space.
            centroid = np.mean(coords, axis=0)

            # Euclidean distance of each point to the centroid.
            distances = np.linalg.norm(coords - centroid, axis=1)

            # Indices of the target_size closest points.
            closest_indices = np.argsort(distances)[:target_size]

            return coords[closest_indices]

        # Downsample every endpoint cluster to the common size.
        x1_1_coords = select_closest_to_centroid(x1_1_coords, target_size)
        x1_2_coords = select_closest_to_centroid(x1_2_coords, target_size)
        x1_3_coords = select_closest_to_centroid(x1_3_coords, target_size)

        dmso_cluster_counts = dmso_data[leiden_dmso_col].value_counts()

        # Source population: the largest DMSO Leiden cluster.
        largest_dmso_cluster = dmso_cluster_counts.index[0]
        dmso_cluster_data = dmso_data[dmso_data[leiden_dmso_col] == largest_dmso_cluster]

        dmso_coords = dmso_cluster_data[pc_cols].values

        # Match the DMSO sample count to target_size, using the same
        # centroid-based selection as for the endpoints.
        if len(dmso_coords) >= target_size:
            x0_coords = select_closest_to_centroid(dmso_coords, target_size)
        else:
            # Largest cluster is too small: pad with cells from the other
            # DMSO clusters.
            remaining_needed = target_size - len(dmso_coords)
            other_dmso_data = dmso_data[dmso_data[leiden_dmso_col] != largest_dmso_cluster]
            other_dmso_coords = other_dmso_data[pc_cols].values

            if len(other_dmso_coords) >= remaining_needed:
                # Take the deficit from the remaining DMSO cells.
                other_selected = select_closest_to_centroid(other_dmso_coords, remaining_needed)
                x0_coords = np.vstack([dmso_coords, other_selected])
            else:
                # Still not enough DMSO cells: shrink target_size to what is
                # available and re-derive every population at the new size.
                all_dmso_coords = dmso_data[pc_cols].values
                target_size = min(target_size, len(all_dmso_coords))
                x0_coords = select_closest_to_centroid(all_dmso_coords, target_size)

                # Re-select endpoint clusters with the reduced target size.
                x1_1_coords = select_closest_to_centroid(x1_1_data[pc_cols].values.astype(float), target_size)
                x1_2_coords = select_closest_to_centroid(x1_2_data[pc_cols].values.astype(float), target_size)
                x1_3_coords = select_closest_to_centroid(x1_3_data[pc_cols].values.astype(float), target_size)

        self.n_samples = target_size

        x0 = torch.tensor(x0_coords, dtype=torch.float32)
        x1_1 = torch.tensor(x1_1_coords, dtype=torch.float32)
        x1_2 = torch.tensor(x1_2_coords, dtype=torch.float32)
        x1_3 = torch.tensor(x1_3_coords, dtype=torch.float32)
        # All three endpoint clusters stacked; kept for plotting/labels only,
        # since training below uses just branch 1 (single-branch variant).
        x1 = torch.cat([x1_1, x1_2, x1_3], dim=0)

        # for plotting
        self.coords_t0 = x0
        self.coords_t1 = x1

        self.time_labels = np.concatenate([
            np.zeros(len(self.coords_t0)),  # t=0
            np.ones(len(self.coords_t1)),  # t=1
        ])

        # Train/val split along the sample axis; keep at least one full batch
        # in validation because the loaders use drop_last=True.
        split_index = int(target_size * self.split_ratios[0])

        if target_size - split_index < self.batch_size:
            split_index = target_size - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        # Single-branch training: only endpoint cluster 1 is used as x1.
        train_x1 = x1_1[:split_index]
        val_x1 = x1_1[split_index:]

        self.val_x0 = val_x0

        # Uniform per-sample weights (single branch => no branch weighting).
        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        train_x1_weights = torch.full((train_x1.shape[0], 1), fill_value=1.0)

        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)
        val_x1_weights = torch.full((val_x1.shape[0], 1), fill_value=1.0)

        self.train_dataloaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
            "x1": DataLoader(TensorDataset(train_x1, train_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # NOTE(review): "x0" is unshuffled here but "x1" is shuffled -- looks
        # inconsistent; confirm whether the val pairing is meant to be random.
        self.val_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
            "x1": DataLoader(TensorDataset(val_x1, val_x1_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }

        # Full point cloud (all conditions) backing the manifold projection.
        all_coords = df[pc_cols].dropna().values.astype(float)
        self.dataset = torch.tensor(all_coords, dtype=torch.float32)
        self.tree = cKDTree(all_coords)

        # Test loaders yield each tensor as a single full batch.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric samples: a 2-way KMeans partition of the full dataset,
        # clustered on the first 3 PCs only.
        #km_all = KMeans(n_clusters=4, random_state=42).fit(self.dataset.numpy())
        km_all = KMeans(n_clusters=2, random_state=0).fit(self.dataset[:, :3].numpy())

        cluster_labels = km_all.labels_

        cluster_0_mask = cluster_labels == 0
        cluster_1_mask = cluster_labels == 1

        samples = self.dataset.cpu().numpy()

        cluster_0_data = samples[cluster_0_mask]
        cluster_1_data = samples[cluster_1_mask]

        # One full-batch loader per KMeans cluster (cluster 1 listed first).
        self.metric_samples_dataloaders = [
            DataLoader(
                torch.tensor(cluster_1_data, dtype=torch.float32),
                batch_size=cluster_1_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
            DataLoader(
                torch.tensor(cluster_0_data, dtype=torch.float32),
                batch_size=cluster_0_data.shape[0],
                shuffle=False,
                drop_last=False,
            ),
        ]

    def train_dataloader(self):
        """Combine training loaders with metric-sample loaders; inner groups
        stop at the shortest loader, outer groups cycle together."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Validation counterpart of train_dataloader()."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Test counterpart: full-batch loaders plus metric samples."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(
                self.metric_samples_dataloaders, mode="min_size"
            ),
        }

        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Return a k-NN smoothing operator toward the data manifold
        (``points`` is unused; kept for interface compatibility)."""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """
        Pull each point in ``x`` toward a Gaussian-weighted average of its k
        nearest neighbours in ``dataset`` (found via the cKDTree ``tree``).
        Replaces plane projection for manifold regularization.
        """
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]  # (batch, k, D); SciPy squeezes to (batch, D) when k == 1

        # Squared distances -> Gaussian weights with temperature ``temp``.
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)

        # Weighted neighbour average.
        smoothed = (weights * nearest_pts).sum(dim=1)

        # Blend the original point with its smoothed version.
        alpha = 0.3  # smoothing strength
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return coordinates grouped by timepoint plus per-point time labels."""
        return {
            't0': self.coords_t0,
            't1': self.coords_t1,
            'time_labels': self.time_labels
        }
dataloaders/veres_leiden_data.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+ from sklearn.preprocessing import StandardScaler
4
+ import pytorch_lightning as pl
5
+ from torch.utils.data import DataLoader
6
+ from lightning.pytorch.utilities.combined_loader import CombinedLoader
7
+ import numpy as np
8
+ from scipy.spatial import cKDTree
9
+ import math
10
+ from functools import partial
11
+ import matplotlib.pyplot as plt
12
+ import pandas as pd
13
+ from torch.utils.data import TensorDataset
14
+ from sklearn.neighbors import kneighbors_graph
15
+ import igraph as ig
16
+ from leidenalg import find_partition, ModularityVertexPartition
17
+
18
class WeightedBranchedVeresDataModule(pl.LightningDataModule):
    """Data module for the branched Veres time-course dataset.

    Branches at the final timepoint are discovered by Leiden clustering of a
    k-NN graph; small clusters are merged into (or optionally discarded in
    favour of) large ones. Exposes x0 / intermediate-timepoint / per-branch
    endpoint dataloaders with branch weights proportional to cluster size.
    """

    def __init__(self, args):
        """Store config values and eagerly build all dataloaders.

        Args:
            args: parsed argument namespace providing ``data_path``,
                ``batch_size``, ``dim``, ``whiten``, ``split_ratios`` and
                ``metric_clusters``; optionally ``branches`` and ``discard``.
        """
        super().__init__()
        self.save_hyperparameters()  # record init args in checkpoints/loggers

        self.data_path = args.data_path
        self.batch_size = args.batch_size
        self.max_dim = args.dim
        self.whiten = args.whiten
        self.k = 20  # k-NN graph size (NOTE: _prepare_data uses its own local k)
        self.num_timesteps = 8
        # initial placeholder, overwritten with the merged Leiden cluster count
        self.num_branches = args.branches if hasattr(args, 'branches') else None
        self.split_ratios = args.split_ratios
        self.metric_clusters = args.metric_clusters
        self.discard_small = args.discard if hasattr(args, 'discard') else False
        self.args = args
        self._prepare_data()

    def _prepare_data(self):
        """Load the CSV, Leiden-cluster the final timepoint, build loaders."""
        print("Preparing Veres cell data with Leiden clustering in WeightedBranchedVeresLeidenDataModule")
        df = pd.read_csv(self.data_path)

        # Coordinates per timepoint, keyed by the 'samples' column value.
        coords_by_t = {
            t: df[df["samples"] == t].iloc[:, 1:].values  # Skip 'samples' column
            for t in sorted(df["samples"].unique())
        }

        n0 = coords_by_t[0].shape[0]  # population size at t=0
        self.n_samples = n0

        print("Timepoint distribution:")
        for t in sorted(coords_by_t.keys()):
            print(f" t={t}: {coords_by_t[t].shape[0]} points")

        # ---- Leiden clustering on the final timepoint ----
        final_t = max(coords_by_t.keys())
        coords_final = coords_by_t[final_t]
        k = 20  # neighbours for the k-NN connectivity graph
        knn_graph = kneighbors_graph(coords_final, k, mode='connectivity', include_self=False)
        sources, targets = knn_graph.nonzero()
        edgelist = list(zip(sources.tolist(), targets.tolist()))
        graph = ig.Graph(edgelist, directed=False)
        partition = find_partition(graph, ModularityVertexPartition)
        leiden_labels = np.array(partition.membership)
        n_leiden = len(np.unique(leiden_labels))
        print(f"Leiden found {n_leiden} clusters at t={final_t}")

        df_final = df[df["samples"] == final_t].copy()
        df_final["branch"] = leiden_labels

        cluster_counts = df_final["branch"].value_counts().sort_index()
        print(f"Branch distribution at t={final_t} (pre-merge):")
        print(cluster_counts)

        # ---- Merge (or discard) clusters below the size threshold ----
        min_cells = 100  # clusters smaller than this are not kept as branches
        cluster_data_dict = {}
        cluster_sizes = []
        for b in range(n_leiden):
            # iloc[:, 1:-1] drops the leading 'samples' column and the
            # trailing 'branch' column just added.
            branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values
            cluster_data_dict[b] = branch_data
            cluster_sizes.append(branch_data.shape[0])

        large_clusters = [b for b, size in enumerate(cluster_sizes) if size >= min_cells]
        small_clusters = [b for b, size in enumerate(cluster_sizes) if size < min_cells]

        # Degenerate case: everything is "small" -> keep all clusters as-is.
        if len(large_clusters) == 0:
            large_clusters = list(range(n_leiden))
            small_clusters = []

        if self.discard_small:
            # Drop small clusters entirely instead of merging them.
            print(f"Discarding {len(small_clusters)} small clusters (< {min_cells} cells)")
            # Keep only cells from large clusters.
            mask = np.isin(leiden_labels, large_clusters)
            df_final = df_final[mask].copy()
            merged_labels = leiden_labels[mask]

            # Remap surviving labels to contiguous ids 0..n-1.
            new_ids = np.unique(merged_labels)
            id_map = {old: new for new, old in enumerate(new_ids)}
            merged_labels = np.array([id_map[x] for x in merged_labels])
            n_merged = len(np.unique(merged_labels))

            df_final["branch"] = merged_labels
            print(f"Kept {n_merged} large clusters")
        else:
            # Merge each small cluster into the large cluster whose centroid
            # is nearest.
            centroids = {b: np.mean(cluster_data_dict[b], axis=0) for b in range(n_leiden) if cluster_data_dict[b].shape[0] > 0}

            merged_labels = leiden_labels.copy()
            for b in small_clusters:
                if cluster_data_dict[b].shape[0] == 0:
                    continue
                # Nearest large cluster by centroid distance.
                dists = [np.linalg.norm(centroids[b] - centroids[bl]) for bl in large_clusters]
                nearest_large = large_clusters[int(np.argmin(dists))]
                merged_labels[leiden_labels == b] = nearest_large

            # Remap to contiguous ids 0..n-1.
            new_ids = np.unique(merged_labels)
            id_map = {old: new for new, old in enumerate(new_ids)}
            merged_labels = np.array([id_map[x] for x in merged_labels])
            n_merged = len(np.unique(merged_labels))

            df_final["branch"] = merged_labels
            print(f"Merged into {n_merged} clusters")

        cluster_counts_merged = df_final["branch"].value_counts().sort_index()
        print(f"Branch distribution at t={final_t} (post-merge):")
        print(cluster_counts_merged)

        # ---- Resample every branch endpoint to exactly n0 points ----
        endpoints = {}
        cluster_sizes = []  # reused: sizes of the *merged* clusters
        for b in range(n_merged):
            branch_data = df_final[df_final["branch"] == b].iloc[:, 1:-1].values
            cluster_sizes.append(branch_data.shape[0])
            # Sample with replacement only when the branch has fewer than n0 cells.
            replace = branch_data.shape[0] < n0
            sampled_indices = np.random.choice(branch_data.shape[0], size=n0, replace=replace)
            endpoints[b] = branch_data[sampled_indices]
        total_t_final = sum(cluster_sizes)

        x0 = torch.tensor(coords_by_t[0], dtype=torch.float32)
        self.coords_t0 = x0
        # Timepoints strictly between 0 and final_t.
        self.coords_intermediate = {t: torch.tensor(coords_by_t[t], dtype=torch.float32)
                                    for t in coords_by_t.keys() if t != 0 and t != final_t}

        self.branch_endpoints = {b: torch.tensor(endpoints[b], dtype=torch.float32) for b in range(n_merged)}
        self.num_branches = n_merged

        # Per-point time labels (for visualization).
        time_labels_list = [np.zeros(len(self.coords_t0))]
        for t in sorted(self.coords_intermediate.keys()):
            time_labels_list.append(np.ones(len(self.coords_intermediate[t])) * t)
        for b in range(self.num_branches):
            time_labels_list.append(np.ones(len(self.branch_endpoints[b])) * final_t)
        self.time_labels = np.concatenate(time_labels_list)

        # ---- Train/val split (validation keeps at least one full batch) ----
        split_index = int(n0 * self.split_ratios[0])
        if n0 - split_index < self.batch_size:
            split_index = n0 - self.batch_size

        train_x0 = x0[:split_index]
        val_x0 = x0[split_index:]
        self.val_x0 = val_x0

        train_x0_weights = torch.full((train_x0.shape[0], 1), fill_value=1.0)
        val_x0_weights = torch.full((val_x0.shape[0], 1), fill_value=1.0)

        # Branch weights proportional to merged cluster sizes.
        branch_weights = [size / total_t_final for size in cluster_sizes]

        # Split intermediate timepoints for sequential training support.
        # NOTE(review): intermediate tensors are sliced with the t=0
        # split_index -- assumes every timepoint has at least split_index
        # rows; confirm for ragged timepoints.
        train_intermediate = {}
        val_intermediate = {}
        self.train_coords_intermediate = {}  # training-only intermediate data for MMD
        for t in sorted(self.coords_intermediate.keys()):
            coords_t = self.coords_intermediate[t]
            train_coords_t = coords_t[:split_index]
            val_coords_t = coords_t[split_index:]
            train_weights_t = torch.full((train_coords_t.shape[0], 1), fill_value=1.0)
            val_weights_t = torch.full((val_coords_t.shape[0], 1), fill_value=1.0)
            train_intermediate[f"x{t}"] = (train_coords_t, train_weights_t)
            val_intermediate[f"x{t}"] = (val_coords_t, val_weights_t)
            self.train_coords_intermediate[t] = train_coords_t  # keyed by int t

        train_loaders = {
            "x0": DataLoader(TensorDataset(train_x0, train_x0_weights), batch_size=self.batch_size, shuffle=True, drop_last=True),
        }
        val_loaders = {
            "x0": DataLoader(TensorDataset(val_x0, val_x0_weights), batch_size=self.batch_size, shuffle=False, drop_last=True),
        }

        # Add a loader for every intermediate timepoint.
        for t_key in sorted(train_intermediate.keys()):
            train_coords_t, train_weights_t = train_intermediate[t_key]
            val_coords_t, val_weights_t = val_intermediate[t_key]
            train_loaders[t_key] = DataLoader(
                TensorDataset(train_coords_t, train_weights_t),
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True
            )
            val_loaders[t_key] = DataLoader(
                TensorDataset(val_coords_t, val_weights_t),
                batch_size=self.batch_size,
                shuffle=False,
                drop_last=True
            )

        # One weighted loader per branch endpoint.
        for b in range(self.num_branches):
            # Split based on this branch's own size, not the t=0 size.
            branch_size = self.branch_endpoints[b].shape[0]
            branch_split_index = int(branch_size * self.split_ratios[0])
            if branch_size - branch_split_index < self.batch_size:
                branch_split_index = max(0, branch_size - self.batch_size)

            train_branch = self.branch_endpoints[b][:branch_split_index]
            val_branch = self.branch_endpoints[b][branch_split_index:]
            train_branch_weights = torch.full((train_branch.shape[0], 1), fill_value=branch_weights[b])
            val_branch_weights = torch.full((val_branch.shape[0], 1), fill_value=branch_weights[b])
            train_loaders[f"x1_{b+1}"] = DataLoader(
                TensorDataset(train_branch, train_branch_weights),
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True
            )
            val_loaders[f"x1_{b+1}"] = DataLoader(
                TensorDataset(val_branch, val_branch_weights),
                batch_size=self.batch_size,
                shuffle=True,
                drop_last=True
            )

        self.train_dataloaders = train_loaders
        self.val_dataloaders = val_loaders

        # Full dataset (all timepoints) backing the manifold projection.
        all_data_list = [coords_by_t[t] for t in sorted(coords_by_t.keys())]
        all_data = np.vstack(all_data_list)
        self.dataset = torch.tensor(all_data, dtype=torch.float32)
        self.tree = cKDTree(all_data)

        # Test loaders yield each tensor as a single full batch.
        self.test_dataloaders = {
            "x0": DataLoader(TensorDataset(self.val_x0, val_x0_weights), batch_size=self.val_x0.shape[0], shuffle=False, drop_last=False),
            "dataset": DataLoader(TensorDataset(self.dataset), batch_size=self.dataset.shape[0], shuffle=False, drop_last=False),
        }

        # Metric dataloaders: t=0 population vs everything later
        # (intermediate timepoints + branch endpoints).
        cluster_0_data = self.coords_t0.cpu().numpy()
        cluster_1_list = [self.coords_intermediate[t].cpu().numpy() for t in sorted(self.coords_intermediate.keys())]
        cluster_1_list.extend([self.branch_endpoints[b].cpu().numpy() for b in range(self.num_branches)])
        cluster_1_data = np.vstack(cluster_1_list)

        self.metric_samples_dataloaders = [
            DataLoader(torch.tensor(cluster_0_data, dtype=torch.float32), batch_size=cluster_0_data.shape[0], shuffle=False, drop_last=False),
            DataLoader(torch.tensor(cluster_1_data, dtype=torch.float32), batch_size=cluster_1_data.shape[0], shuffle=False, drop_last=False),
        ]

    def train_dataloader(self):
        """Combine training loaders with metric-sample loaders; inner groups
        stop at the shortest loader, outer groups cycle together."""
        combined_loaders = {
            "train_samples": CombinedLoader(self.train_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Validation counterpart of train_dataloader()."""
        combined_loaders = {
            "val_samples": CombinedLoader(self.val_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def test_dataloader(self):
        """Test counterpart: full-batch loaders plus metric samples."""
        combined_loaders = {
            "test_samples": CombinedLoader(self.test_dataloaders, mode="min_size"),
            "metric_samples": CombinedLoader(self.metric_samples_dataloaders, mode="min_size"),
        }
        return CombinedLoader(combined_loaders, mode="max_size_cycle")

    def get_manifold_proj(self, points):
        """Return a k-NN smoothing operator toward the data manifold
        (``points`` is unused; kept for interface compatibility)."""
        return partial(self.local_smoothing_op, tree=self.tree, dataset=self.dataset)

    @staticmethod
    def local_smoothing_op(x, tree, dataset, k=10, temp=1e-3):
        """Pull each point in ``x`` toward a Gaussian-weighted average of its
        k nearest neighbours in ``dataset`` (found via the cKDTree ``tree``)."""
        points_np = x.detach().cpu().numpy()
        _, idx = tree.query(points_np, k=k)
        nearest_pts = dataset[idx]
        # Squared distances -> Gaussian weights with temperature ``temp``.
        dists = (x.unsqueeze(1) - nearest_pts).pow(2).sum(-1, keepdim=True)
        weights = torch.exp(-dists / temp)
        weights = weights / weights.sum(dim=1, keepdim=True)
        smoothed = (weights * nearest_pts).sum(dim=1)
        alpha = 0.3  # blend factor toward the smoothed position
        return (1 - alpha) * x + alpha * smoothed

    def get_timepoint_data(self):
        """Return coordinates grouped by timepoint plus per-point time labels."""
        result = {
            't0': self.coords_t0,
            'time_labels': self.time_labels
        }
        # intermediate timepoints
        for t in sorted(self.coords_intermediate.keys()):
            result[f't{t}'] = self.coords_intermediate[t]
        # NOTE(review): assumes timepoints are consecutive integers, so the
        # final timepoint is max(intermediate) + 1 -- confirm for sparse grids.
        final_t = max([0] + list(self.coords_intermediate.keys())) + 1
        for b in range(self.num_branches):
            result[f't{final_t}_{b}'] = self.branch_endpoints[b]
        return result

    def get_train_intermediate_data(self):
        """Return train-split intermediate coords, falling back to the full
        per-timepoint coords when the split was never built."""
        if hasattr(self, 'train_coords_intermediate'):
            return self.train_coords_intermediate
        else:
            # Fallback to full intermediate data if train split not available
            print("Warning: train_coords_intermediate not found, returning full intermediate data.")
            return self.coords_intermediate
environment.yml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: branchsbm
2
+ channels:
3
+ - conda-forge
4
+ - pytorch
5
+ - defaults
6
+ dependencies:
7
+ - conda-forge::python=3.10
8
+ - conda-forge::openssl
9
+ - ca-certificates
10
+ - certifi
11
+ - pytorch::pytorch
12
+ - matplotlib
13
+ - pandas
14
+ - seaborn
15
+ - torchmetrics
16
+ - numpy>=1.26.0,<2.0.0
17
+ - scikit-learn
18
+ - pyyaml
19
+ - jupyter
20
+ - ipykernel
21
+ - notebook
22
+ - tqdm
23
+ - pytorch-lightning>=2.0.0
24
+ - lightning>=2.0.0
25
+ - python-igraph
26
+ - leidenalg
27
+ - pip
28
+ - pip:
29
+ - scipy==1.13.1
30
+ - wandb==0.22.1
31
+ - torchcfm==1.0.7
32
+ - torchdyn==1.0.6
33
+ - torchdiffeq
34
+ - pot
35
+ - hydra-core
36
+ - omegaconf
37
+ - laspy
38
+ - umap-learn
39
+ - scanpy
40
+ - lpips
41
+ - geomloss
parsers.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
def parse_args():
    """Build the full BranchSBM command-line interface and parse it.

    Registers the top-level options here, then delegates to the
    section-specific parser builders (datasets, images, metrics, general
    training, geopath/flow/growth networks) before parsing ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Train BranchSBM")

    parser.add_argument("--seed", default=2, type=int)
    parser.add_argument(
        "--config_path", type=str,
        default='',
        help="Path to config file"
    )

    # Options the training code iterates over.
    parser.add_argument(
        "--seeds",
        nargs="+",
        type=int,
        default=[42, 43, 44, 45, 46],
        help="Random seeds to iterate over",
    )
    parser.add_argument(
        "--t_exclude",
        nargs="+",
        type=int,
        default=None,
        help="Time points to exclude (iterating over)",
    )

    parser.add_argument(
        "--working_dir",
        type=str,
        default="path/to/your/home/BranchSBM",
        help="Working directory",
    )
    parser.add_argument(
        "--resume_flow_model_ckpt",
        type=str,
        default=None,
        help="Path to the flow model to resume training",
    )
    parser.add_argument(
        "--resume_growth_model_ckpt",
        type=str,
        default=None,
        # Fixed copy-pasted help text: this option points at the growth model.
        help="Path to the growth model to resume training",
    )
    parser.add_argument(
        "--load_geopath_model_ckpt",
        type=str,
        default=None,
        help="Path to the geopath model to resume training",
    )
    parser.add_argument(
        "--sequential",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use sequential training for multi-timepoint data",
    )
    parser.add_argument(
        "--discard",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Discard small clusters instead of merging them in Leiden clustering",
    )
    parser.add_argument(
        "--pseudo",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use pseudotime-based clustering for Weinreb data instead of Leiden on t=2",
    )
    parser.add_argument(
        "--branches",
        type=int,
        default=2,
        help="Number of branches",
    )
    parser.add_argument(
        "--metric_clusters",
        type=int,
        default=3,
        help="Number of metric clusters",
    )
    parser.add_argument(
        "--resolution",
        type=float,
        default=1.0,
        help="Resolution parameter for Leiden clustering",
    )

    # Section-specific option groups, applied in the original registration order.
    for register in (
        datasets_parser,
        image_datasets_parser,
        metric_parser,
        general_training_parser,
        geopath_network_parser,
        flow_network_parser,
        growth_network_parser,
    ):
        parser = register(parser)

    return parser.parse_args()
118
+
119
+
120
def datasets_parser(parser):
    """Register dataset options on *parser* and return it.

    Covers data location/type, whitening, Leiden-clustering knobs, and the
    pseudotime-specific options that only take effect when ``--pseudo`` is set.
    """
    parser.add_argument("--dim", type=int, default=3, help="Dimension of data")

    parser.add_argument(
        "--data_type",
        type=str,
        default="lidar",
        # Fixed typo in help text ("wither" -> "either").
        help="Type of data, now either scrna or one of toys",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="",
        help="lidar data path",
    )
    parser.add_argument(
        "--data_name",
        type=str,
        default="lidar",
        help="Path to the dataset",
    )
    parser.add_argument(
        "--whiten",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whiten the data",
    )
    parser.add_argument(
        "--min_cells",
        type=int,
        default=500,
        help="Minimum cells per cluster for Leiden clustering",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=20,
        help="Number of neighbors for KNN graph in Leiden clustering",
    )
    parser.add_argument(
        "--pseudotime_threshold",
        type=float,
        default=0.6,
        help="Pseudotime threshold for terminal cells (only used when --pseudo is True)",
    )
    parser.add_argument(
        "--terminal_neighbors",
        type=int,
        default=20,
        help="Number of neighbors for terminal cell clustering (only used when --pseudo is True)",
    )
    parser.add_argument(
        "--terminal_resolution",
        type=float,
        default=0.2,
        help="Resolution for terminal cell Leiden clustering (only used when --pseudo is True)",
    )
    parser.add_argument(
        "--n_dcs",
        type=int,
        default=10,
        help="Number of diffusion components for DPT (only used when --pseudo is True)",
    )
    parser.add_argument(
        "--initial_neighbors",
        type=int,
        default=30,
        help="Number of neighbors for initial kNN graph (only used when --pseudo is True)",
    )
    parser.add_argument(
        "--initial_resolution",
        type=float,
        default=1.0,
        help="Resolution for initial Leiden clustering (only used when --pseudo is True)",
    )
    return parser
196
+
197
+
198
def image_datasets_parser(parser):
    """Attach image-dataset options (resolution plus x0/x1 class labels)."""
    parser.add_argument("--image_size", type=int, default=128, help="Size of the image")
    parser.add_argument("--x0_label", type=str, default="dog", help="Label for x0")
    parser.add_argument("--x1_label", type=str, default="cat", help="Label for x1")
    return parser
218
+
219
+
220
def metric_parser(parser):
    """Register metric-learning options on *parser* and return it.

    Covers the branched-SBM toggle, RBF-network shape parameters, the
    Riemannian-velocity parameters, and the metric-learning schedule.
    """
    parser.add_argument(
        "--branchsbm",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="If branched SBM",
    )
    parser.add_argument(
        "--n_centers",
        type=int,
        default=100,
        help="Number of centers for RBF network",
    )
    parser.add_argument(
        "--kappa",
        type=float,
        default=1.0,
        help="Kappa parameter for RBF network",
    )
    parser.add_argument(
        "--rho",
        type=float,
        default=0.001,
        # Fixed spelling in help text ("Riemanian" -> "Riemannian").
        help="Rho parameter in Riemannian Velocity Calculation",
    )
    parser.add_argument(
        "--velocity_metric",
        type=str,
        default="rbf",
        help="Metric for velocity calculation",
    )
    parser.add_argument(
        "--gammas",
        nargs="+",
        type=float,
        default=[0.2, 0.2],
        help="Gamma parameter in Riemannian Velocity Calculation",
    )

    parser.add_argument(
        "--metric_epochs",
        type=int,
        default=100,
        help="Number of epochs for metric learning",
    )
    parser.add_argument(
        "--metric_patience",
        type=int,
        default=20,
        help="Patience for metric learning",
    )
    parser.add_argument(
        "--metric_lr",
        type=float,
        default=1e-2,
        help="Learning rate for metric learning",
    )
    parser.add_argument(
        "--alpha_metric",
        type=float,
        default=1.0,
        help="Alpha parameter for metric learning",
    )

    return parser
285
+
286
+
287
def general_training_parser(parser):
    """Attach general training options (batching, OT, EMA, epochs, logging)."""
    parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training")
    parser.add_argument(
        "--optimal_transport_method",
        type=str,
        default="exact",
        help="Use optimal transport in CFM training",
    )
    parser.add_argument("--ema_decay", type=float, default=None, help="Decay for EMA")
    parser.add_argument(
        "--split_ratios",
        nargs=2,
        type=float,
        default=[0.9, 0.1],
        help="Split ratios for training/validation data in CFM training",
    )
    parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
    parser.add_argument("--accelerator", type=str, default="gpu", help="Training accelerator")
    parser.add_argument("--run_name", type=str, default=None, help="Name for the wandb run")
    parser.add_argument(
        "--sim_num_steps",
        type=int,
        default=1000,
        help="Number of steps in simulation",
    )
    return parser
324
+
325
+
326
def geopath_network_parser(parser):
    """Attach GeoPath-network options (architecture, optimizer, MMD weight)."""
    parser.add_argument(
        "--manifold",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="If use data manifold metric",
    )
    parser.add_argument(
        "--patience_geopath", type=int, default=50,
        help="Patience for training geopath model",
    )
    parser.add_argument(
        "--hidden_dims_geopath", nargs="+", type=int, default=[64, 64, 64],
        help="Dimensions of hidden layers for GeoPath model training",
    )
    parser.add_argument(
        "--time_geopath",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use time in GeoPath model",
    )
    parser.add_argument(
        "--activation_geopath", type=str, default="selu",
        help="Activation function for GeoPath",
    )
    parser.add_argument(
        "--geopath_optimizer", type=str, default="adam",
        help="Optimizer for GeoPath training",
    )
    parser.add_argument(
        "--geopath_lr", type=float, default=1e-4,
        help="Learning rate for GeoPath training",
    )
    parser.add_argument(
        "--geopath_weight_decay", type=float, default=1e-5,
        help="Weight decay for GeoPath training",
    )
    parser.add_argument(
        "--mmd_weight", type=float, default=0.1,
        help="Weight for MMD loss at intermediate timepoints (only used when >2 timepoints)",
    )
    return parser
383
+
384
+
385
def flow_network_parser(parser):
    """Register flow-network (CFM) options on *parser* and return it.

    Covers the CFM variance, early stopping, architecture, validation
    cadence, and the flow optimizer settings.
    """
    parser.add_argument(
        "--sigma", type=float, default=0.1, help="Sigma parameter for CFM (variance)"
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=5,
        help="Patience for early stopping in CFM training",
    )
    parser.add_argument(
        "--hidden_dims_flow",
        nargs="+",
        type=int,
        default=[64, 64, 64],
        help="Dimensions of hidden layers for CFM training",
    )
    parser.add_argument(
        "--check_val_every_n_epoch",
        type=int,
        default=10,
        help="Check validation every N epochs during CFM training",
    )
    parser.add_argument(
        "--activation_flow",
        type=str,
        default="selu",
        help="Activation function for CFM",
    )
    # Help strings below previously said "GeoPath training" (copy-paste error);
    # these options configure the flow network.
    parser.add_argument(
        "--flow_optimizer",
        type=str,
        default="adamw",
        help="Optimizer for flow network training",
    )
    parser.add_argument(
        "--flow_lr",
        type=float,
        default=1e-3,
        help="Learning rate for flow network training",
    )
    parser.add_argument(
        "--flow_weight_decay",
        type=float,
        default=1e-5,
        help="Weight decay for flow network training",
    )
    return parser
433
+
434
def growth_network_parser(parser):
    """Register growth-network options on *parser* and return it.

    Covers early stopping, architecture, the growth optimizer settings, and
    the loss-term weights (energy / mass / match / reconstruction).
    """
    # Help strings below previously referred to "CFM" / "GeoPath" (copy-paste
    # errors); these options configure the growth network.
    parser.add_argument(
        "--patience_growth",
        type=int,
        default=5,
        help="Patience for early stopping in growth network training",
    )
    parser.add_argument(
        "--time_growth",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Use time in growth model",
    )
    parser.add_argument(
        "--hidden_dims_growth",
        nargs="+",
        type=int,
        default=[64, 64, 64],
        help="Dimensions of hidden layers for growth net training",
    )
    parser.add_argument(
        "--activation_growth",
        type=str,
        default="tanh",
        help="Activation function for growth network",
    )
    parser.add_argument(
        "--growth_optimizer",
        type=str,
        default="adamw",
        help="Optimizer for growth network training",
    )
    parser.add_argument(
        "--growth_lr",
        type=float,
        default=1e-3,
        help="Learning rate for growth network training",
    )
    parser.add_argument(
        "--growth_weight_decay",
        type=float,
        default=1e-5,
        help="Weight decay for growth network training",
    )
    parser.add_argument(
        "--lambda_energy",
        type=float,
        default=1.0,
        help="Weight for energy loss",
    )
    parser.add_argument(
        "--lambda_mass",
        type=float,
        default=100.0,
        help="Weight for mass loss",
    )
    parser.add_argument(
        "--lambda_match",
        type=float,
        default=1000.0,
        help="Weight for matching loss",
    )
    parser.add_argument(
        "--lambda_recons",
        type=float,
        default=1.0,
        help="Weight for reconstruction loss",
    )
    return parser
scripts/README.md ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Running Experiments with BranchSBM 🌳🧬
2
+
3
+ This directory contains training scripts for all experiments with BranchSBM, including LiDAR navigation 🗻, simulating cell differentiation 🧫, and cell state perturbation modelling 🧬. This codebase contains code from the [Metric Flow Matching repo](https://github.com/kkapusniak/metric-flow-matching) ([Kapusniak et al. 2024](https://arxiv.org/abs/2405.14780)).
4
+
5
+ ## Environment Installation
6
+ ```
7
+ conda env create -f environment.yml
8
+
9
+ conda activate branchsbm
10
+ ```
11
+
12
+ ## Data
13
+ LiDAR data is taken from the [Generalized Schrödinger Bridge Matching repo](https://github.com/facebookresearch/generalized-schrodinger-bridge-matching) and Mouse Hematopoiesis is taken from the [DeepRUOT repo](https://github.com/zhenyiizhang/DeepRUOT)
14
+
15
+ We use perturbation data from the [Tahoe-100M dataset](https://huggingface.co/datasets/tahoebio/Tahoe-100M) containing control DMSO-treated cell data and perturbed cell data.
16
+
17
+ The raw data contains a total of 60K genes. We select the top 2000 highly variable genes (HVGs) and perform principal component analysis (PCA), to maximally capture the variance in the data via the top principal components (38% in the top-50 PCs). **Our goal is to learn the dynamic trajectories that map control cell clusters to the perturbed cell clusters.**
18
+
19
+ **Specifically, we model the following perturbations**:
20
+
21
+ 1. **Clonidine**: Cell states under 5uM Clonidine perturbation at various PC dimensions (50D, 100D, 150D) with 1 unseen population.
22
+ 2. **Trametinib**: Cell states under 5uM Trametinib perturbation (50D) with 2 unseen populations.
23
+
24
+ All data files are stored in:
25
+ ```
26
+ BranchSBM/data/
27
+ ├── rainier2-thin.las # LiDAR data
28
+ ├── mouse_hematopoiesis.csv # Mouse Hematopoiesis data
29
+ ├── pca_and_leiden_labels.csv # Clonidine data
30
+ ├── Trametinib_5.0uM_pca_and_leidenumap_labels.csv # Trametinib data
31
+ └── Veres_alltime.csv # Pancreatic β-Cell data
32
+ ```
33
+
34
+ ## Running Experiments
35
+
36
+ All training scripts are located in `BranchSBM/scripts/`. Each script is pre-configured for a specific experiment.
37
+
38
+ The scripts for BranchSBM experiments include:
39
+
40
+ - **`lidar.sh`** - LiDAR trajectory data with 2 branches
41
+ - **`mouse.sh`** - Mouse cell differentiation with 2 branches
42
+ - **`clonidine.sh`** - Clonidine perturbation with 2 branches
43
+ - **`trametinib.sh`** - Trametinib perturbation with 3 branches
44
+ - **`veres.sh`** - Pancreatic beta-cell differentiation with 11 branches
45
+
46
+
47
+ The scripts for the baseline single-branch SBM experiments include:
48
+
49
+ - **`mouse_single.sh`** - Mouse single branch
50
+ - **`clonidine_single.sh`** - Clonidine single branch
51
+ - **`trametinib_single.sh`** - Trametinib single branch
52
+ - **`lidar_single.sh`** - LiDAR single branch
53
+
54
+ **Before running experiments:**
55
+
56
+ 1. Set `HOME_LOC` to the base path where BranchSBM is located and `ENV_PATH` to the directory where your environment is downloaded in the `.sh` files in `scripts/`
57
+ 2. Create a path `BranchSBM/results` where the simulated trajectory figures and metrics will be saved. Also, create `BranchSBM/logs` where the training logs will be saved.
58
+ 3. Activate the conda environment:
59
+ ```
60
+ conda activate branchsbm
61
+ ```
62
+ 4. Login to wandb using `wandb login`
63
+
64
+ **Run experiment using `nohup` with the following commands:**
65
+
66
+ ```
67
+ cd scripts
68
+
69
+ chmod +x lidar.sh
70
+
71
+ nohup ./lidar.sh > lidar.log 2>&1 &
72
+ ```
73
+
74
+ Evaluation will run automatically after the specified number of rollouts `--num_rollouts` is finished. To see metrics, go to `results/<experiment>/metrics/` or the end of `logs/<experiment>.log`.
75
+
76
+ For Clonidine, `x1_1` indicates the cell cluster that is sampled from for training and `x1_2` is the held-out cell cluster. For Trametinib `x1_1` indicates the cell cluster that is sampled from for training and `x1_2` and `x1_3` are the held-out cell clusters.
77
+
78
+ We report the following metrics for each of the clusters in our paper:
79
+ 1. Maximum Mean Discrepancy (RBF-MMD) of simulated cell cluster with target cell cluster (same cell count).
80
+ 2. 1-Wasserstein and 2-Wasserstein distances against full cell population in the cluster.
81
+
82
+ ## Overview of Outputs
83
+
84
+ **Training outputs are saved to experiment-specific directories:**
85
+
86
+ ```
87
+ BranchSBM/results/
88
+ ├── <DATE>_clonidine50D_branched/
89
+ │ └── figures/ # Figures of simulated
90
+ │ └── metrics.csv # JSON of metrics
91
+ ```
92
+
93
+ **PyTorch Lightning automatically saves model checkpoints to:**
94
+
95
+ ```
96
+ BranchSBM/scripts/lightning_logs/
97
+ ├── <wandb-run-id>/
98
+ │ ├── checkpoints/
99
+ │ │ ├── epoch=N-step=M.ckpt # Checkpoint
100
+ ```
101
+
102
+ **Training logs are saved in:**
103
+ ```
104
+ BranchSBM/logs/
105
+ ├── <DATE>_lidar_single_train.log
106
+ ├── <DATE>_lidar_train.log
107
+ ├── <DATE>_mouse_single_train.log
108
+ ├── <DATE>_mouse_train.log
109
+ ├── <DATE>_clonidine_single_train.log
110
+ ├── <DATE>_clonidine50D_train.log
111
+ ├── <DATE>_clonidine100D_train.log
112
+ ├── <DATE>_clonidine150D_train.log
113
+ ├── <DATE>_trametinib_single_train.log
114
+ ├── <DATE>_trametinib_train.log
115
+ └── <DATE>_veres_train.log
116
+ ```
117
+
118
+ ## Available Experiments
119
+
120
+ ### Branched Experiments (Multi-branch trajectories)
121
+
122
+ These experiments model cell differentiation or perturbation with multiple branches:
123
+
124
+ - **`mouse.sh`** - Mouse cell differentiation with 2 branches (GPU 0)
125
+ - **`trametinib.sh`** - Trametinib perturbation with 3 branches (GPU 1)
126
+ - **`lidar.sh`** - LiDAR trajectory data with 2 branches (GPU 2)
127
+ - **`clonidine.sh`** - Clonidine perturbation with 2 branches (GPU 3)
128
+
129
+ ### Single-Branch Experiments (Control/baseline)
130
+
131
+ These are baseline experiments with single trajectories:
132
+
133
+ - **`mouse_single.sh`** - Mouse single trajectory (GPU 4)
134
+ - **`clonidine_single.sh`** - Clonidine single trajectory (GPU 5)
135
+ - **`trametinib_single.sh`** - Trametinib single trajectory (GPU 6)
136
+ - **`lidar_single.sh`** - LiDAR single trajectory (GPU 7)
137
+
138
+ ## Running Scripts
139
+
140
+ ### Run a single experiment
141
+
142
+ From the `scripts/` directory:
143
+
144
+ ```bash
145
+ cd scripts
146
+ chmod +x mouse.sh
147
+ nohup ./mouse.sh > mouse.log 2>&1 &
148
+ ```
149
+
150
+ ### Run all branched experiments in parallel
151
+
152
+ ```bash
153
+ nohup ./mouse.sh > mouse.log 2>&1 &
154
+ nohup ./trametinib.sh > trametinib.log 2>&1 &
155
+ nohup ./lidar.sh > lidar.log 2>&1 &
156
+ nohup ./clonidine.sh > clonidine.log 2>&1 &
157
+ ```
158
+
159
+ ### Run all single-branch experiments in parallel
160
+
161
+ ```bash
162
+ nohup ./mouse_single.sh > mouse_single.log 2>&1 &
163
+ nohup ./clonidine_single.sh > clonidine_single.log 2>&1 &
164
+ nohup ./trametinib_single.sh > trametinib_single.log 2>&1 &
165
+ nohup ./lidar_single.sh > lidar_single.log 2>&1 &
166
+ ```
167
+
168
+ ### Run all experiments simultaneously
169
+
170
+ Each script is assigned to a different GPU, so you can run all 8 in parallel:
171
+
172
+ ```bash
173
+ nohup ./mouse.sh > mouse.log 2>&1 &
174
+ nohup ./trametinib.sh > trametinib.log 2>&1 &
175
+ nohup ./lidar.sh > lidar.log 2>&1 &
176
+ nohup ./clonidine.sh > clonidine.log 2>&1 &
177
+ nohup ./mouse_single.sh > mouse_single.log 2>&1 &
178
+ nohup ./clonidine_single.sh > clonidine_single.log 2>&1 &
179
+ nohup ./trametinib_single.sh > trametinib_single.log 2>&1 &
180
+ nohup ./lidar_single.sh > lidar_single.log 2>&1 &
181
+ ```
182
+
183
+ ## Monitoring Training
184
+
185
+ Logs are saved in `./BranchSBM/logs/` with format `MM_DD_<experiment>_train.log`.
186
+
187
+ Each experiment logs to wandb with a unique run name:
188
+ - Branched experiments: `<dataset>_branched` (e.g., `mouse_branched`)
189
+ - Single experiments: `<dataset>_single` (e.g., `mouse_single`)
190
+
191
+ Visit your wandb dashboard to view training progress in real-time.
192
+
193
+ ## Training Parameters
194
+
195
+ Default training parameters for each experiment:
196
+
197
+ | Parameter | LiDAR | Mouse Hematopoiesis scRNA | Clonidine (50 PCs) | Clonidine (100 PCs) | Clonidine (150 PCs) | Trametinib | Pancreatic β-Cell |
198
+ |---|---|---|---|---|---|---|---|
199
+ | branches | 2 | 2 | 2 | 2 | 2 | 3 | 11 |
200
+ | data dimension | 3 | 2 | 50 | 100 | 150 | 50 | 30 |
201
+ | batch size | 128 | 128 | 32 | 32 | 32 | 32 | 256 |
202
+ | λ_energy | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
203
+ | λ_mass | 100 | 100 | 100 | 100 | 100 | 100 | 100 |
204
+ | λ_match | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ | 1.0 × 10³ |
205
+ | λ_recons | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
206
+ | λ_growth | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 | 0.01 |
207
+ | V_t | LAND | LAND | RBF | RBF | RBF | RBF | RBF |
208
+ | RBF N_c | - | - | 150 | 300 | 300 | 150 | 300 |
209
+ | RBF κ | - | - | 1.5 | 2.0 | 3.0 | 1.5 | 3.0 |
210
+ | hidden dimension | 64 | 64 | 1024 | 1024 | 1024 | 1024 | 1024 |
211
+ | lr interpolant | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ | 1.0 × 10⁻⁴ |
212
+ | lr velocity | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ |
213
+ | lr growth | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ | 1.0 × 10⁻³ |
214
+
215
+ To modify parameters, edit the corresponding `.sh` file.
216
+
217
+ ## Training Pipeline
218
+
219
+ Each experiment runs through 4 stages:
220
+
221
+ 1. **Stage 1: Geopath** - Train geodesic path interpolants
222
+ 2. **Stage 2: Flow Matching** - Train continuous normalizing flows
223
+ 3. **Stage 3: Growth** - Train growth networks for branches
224
+ 4. **Stage 4: Joint** - Joint training of all components
225
+
226
+ Checkpoints are saved automatically and loaded between stages.
scripts/clonidine100.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train BranchSBM on Clonidine (100 PCs, branched). Set HOME_LOC/ENV_LOC first.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='clonidine100D_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=4

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Quote expansions and abort if cd fails, so training never runs from the wrong directory.
cd "$HOME_LOC" || exit 1

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/clonidine_100D.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/clonidine150.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train BranchSBM on Clonidine (150 PCs, branched). Set HOME_LOC/ENV_LOC first.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='clonidine150D_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=5

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Quote expansions and abort if cd fails, so training never runs from the wrong directory.
cd "$HOME_LOC" || exit 1

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/clonidine_150D.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/clonidine50.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train BranchSBM on Clonidine (50 PCs, branched). Set HOME_LOC/ENV_LOC first.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='clonidine50D_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=3

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Quote expansions and abort if cd fails, so training never runs from the wrong directory.
cd "$HOME_LOC" || exit 1

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/clonidine_50D.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/clonidine50_single.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ HOME_LOC=/path/to/your/home/BranchSBM
4
+ ENV_LOC=/path/to/your/envs/branchsbm
5
+ SCRIPT_LOC=$HOME_LOC
6
+ LOG_LOC=$HOME_LOC/logs
7
+ DATE=$(date +%m_%d)
8
+ SPECIAL_PREFIX='clonidine_single'
9
+ PYTHON_EXECUTABLE=$ENV_LOC/bin/python
10
+
11
+ # Set GPU device
12
+ export CUDA_VISIBLE_DEVICES=3
13
+
14
+ # ===================================================================
15
+ source "$(conda info --base)/etc/profile.d/conda.sh"
16
+ conda activate $ENV_LOC
17
+
18
+ cd $HOME_LOC
19
+
20
+ $PYTHON_EXECUTABLE $SCRIPT_LOC/train.py \
21
+ --epochs 100 \
22
+ --run_name "clonidine50D_single" \
23
+ --config_path "$SCRIPT_LOC/configs/clonidine_50Dsingle.yaml" \
24
+ --batch_size 32 >> ${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log 2>&1
25
+
26
+ conda deactivate
scripts/lidar.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train BranchSBM on the LiDAR dataset (branched). Set HOME_LOC/ENV_LOC first.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='lidar_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=2

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Quote expansions and abort if cd fails, so training never runs from the wrong directory.
cd "$HOME_LOC" || exit 1

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --config_path "$SCRIPT_LOC/configs/lidar.yaml" \
    --epochs 10 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --batch_size 128 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/lidar_single.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train the single-branch LiDAR baseline. Set HOME_LOC/ENV_LOC first.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='lidar_single'
# set 3 have skip connection
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=2

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Quote expansions and abort if cd fails, so training never runs from the wrong directory.
cd "$HOME_LOC" || exit 1

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --config_path "$SCRIPT_LOC/configs/lidar_single.yaml" \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --epochs 100 \
    --batch_size 128 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/mouse.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train BranchSBM on mouse hematopoiesis (branched). Set HOME_LOC/ENV_LOC first.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='mouse_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=1

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Quote expansions and abort if cd fails, so training never runs from the wrong directory.
cd "$HOME_LOC" || exit 1

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --config_path "$SCRIPT_LOC/configs/mouse.yaml" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/mouse_single.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Train the single-branch mouse hematopoiesis baseline. Set HOME_LOC/ENV_LOC first.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='mouse_single'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=1

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Quote expansions and abort if cd fails, so training never runs from the wrong directory.
cd "$HOME_LOC" || exit 1

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/mouse_single.yaml" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/trametinib.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch BranchSBM training on the trametinib (branched) dataset.
# Edit HOME_LOC / ENV_LOC to point at your checkout and conda env.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='trametinib_branched'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=6

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Abort if the project directory is missing instead of running from the
# wrong working directory.
cd "$HOME_LOC" || exit 1

# The >> redirect below fails if the log directory does not exist yet.
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/trametinib.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/trametinib_single.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch BranchSBM training on the trametinib dataset (single-branch baseline).
# Edit HOME_LOC / ENV_LOC to point at your checkout and conda env.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='trametinib_single'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=6

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Abort if the project directory is missing instead of running from the
# wrong working directory.
cd "$HOME_LOC" || exit 1

# The >> redirect below fails if the log directory does not exist yet.
mkdir -p "$LOG_LOC"

"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --config_path "$SCRIPT_LOC/configs/trametinib_single.yaml" \
    --batch_size 32 >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
scripts/veres.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch BranchSBM training on the Veres dataset.
# Edit HOME_LOC / ENV_LOC to point at your checkout and conda env.

HOME_LOC=/path/to/your/home/BranchSBM
ENV_LOC=/path/to/your/envs/branchsbm
SCRIPT_LOC=$HOME_LOC
LOG_LOC=$HOME_LOC/logs
DATE=$(date +%m_%d)
SPECIAL_PREFIX='veres'
PYTHON_EXECUTABLE=$ENV_LOC/bin/python

# Set GPU device
export CUDA_VISIBLE_DEVICES=7

# ===================================================================
source "$(conda info --base)/etc/profile.d/conda.sh"
conda activate "$ENV_LOC"

# Abort if the project directory is missing instead of running from the
# wrong working directory.
cd "$HOME_LOC" || exit 1

# The >> redirect below fails if the log directory does not exist yet.
mkdir -p "$LOG_LOC"

# NOTE(review): every sibling script passes --config_path; this one passed
# --config. Normalized to --config_path for consistency — confirm against
# the flag actually declared in parsers.py.
"$PYTHON_EXECUTABLE" "$SCRIPT_LOC/train.py" \
    --epochs 100 \
    --run_name "${DATE}_${SPECIAL_PREFIX}" \
    --min_cells 100 \
    --config_path "$SCRIPT_LOC/configs/veres.yaml" >> "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}_train.log" 2>&1

conda deactivate
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/branch_flow_net_test.py ADDED
@@ -0,0 +1,1791 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Separate test classes for each BranchSBM experiment with specific plotting styles.
3
+ Each class handles testing and visualization for: LiDAR, Mouse, Clonidine, Trametinib, Veres.
4
+ """
5
+
6
+ import os
7
+ import json
8
+ import csv
9
+ import torch
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import pytorch_lightning as pl
13
+ import random
14
+ import ot
15
+ from torchdyn.core import NeuralODE
16
+ from matplotlib.colors import LinearSegmentedColormap
17
+ from matplotlib.collections import LineCollection
18
+ from .networks.utils import flow_model_torch_wrapper
19
+ from .branch_flow_net_train import BranchFlowNetTrainBase
20
+ from .branch_growth_net_train import GrowthNetTrain
21
+ from .utils import wasserstein, mix_rbf_mmd2, plot_lidar
22
+ import json
23
+
24
def evaluate_model(gt_data, model_data, a, b):
    """Exact earth-mover distance between ground-truth and model samples.

    Builds the pairwise Euclidean cost matrix between ``gt_data`` (rows) and
    ``model_data`` (columns) and solves the discrete optimal-transport
    problem with marginal weights ``a`` and ``b`` via ``ot.emd2``.

    Args:
        gt_data: ground-truth points, tensor or array-like of shape [N, D].
        model_data: model samples, tensor or array-like of shape [M, D].
        a: marginal weights for ``gt_data`` (length N), as expected by POT.
        b: marginal weights for ``model_data`` (length M).

    Returns:
        The EMD value as a float, or ``np.nan`` if the cost matrix contains
        NaN/Inf entries.
    """
    # Ensure inputs are tensors so .device / torch.cdist are available.
    if not isinstance(gt_data, torch.Tensor):
        gt_data = torch.tensor(gt_data, dtype=torch.float32)
    if not isinstance(model_data, torch.Tensor):
        model_data = torch.tensor(model_data, dtype=torch.float32)

    # Choose device: prefer model_data's device if it's not CPU, otherwise
    # fall back to gt_data's device. (Both are tensors at this point, so
    # .device is always available — no need to guard the attribute access.)
    model_dev = model_data.device
    gt_dev = gt_data.device
    device = model_dev if model_dev.type != 'cpu' else gt_dev

    gt = gt_data.to(device=device, dtype=torch.float32)
    md = model_data.to(device=device, dtype=torch.float32)

    # Pairwise Euclidean cost matrix, moved to CPU/numpy for POT.
    M = torch.cdist(gt, md, p=2).cpu().numpy()
    if np.isnan(M).any() or np.isinf(M).any():
        return np.nan
    # numItermax must be an integer; the float literal 1e7 is rejected /
    # warned about by recent versions of POT.
    return ot.emd2(a, b, M, numItermax=int(1e7))
50
+
51
def compute_distribution_distances(pred, true, pred_full=None, true_full=None):
    """Compute W1, W2 and RBF-MMD between predicted and true point sets.

    W1/W2 are computed on ``pred`` vs ``true`` directly. The MMD is computed
    on ``pred_full``/``true_full`` when provided (e.g. full-dimensional
    points while W1/W2 use a 2D projection), otherwise on the same inputs.
    Returns a dict with keys "W1", "W2" and "MMD".
    """
    results = {
        "W1": wasserstein(pred, true, power=1),
        "W2": wasserstein(pred, true, power=2),
    }

    # MMD operands: fall back to the W1/W2 inputs when no full-dim given.
    xs = pred if pred_full is None else pred_full
    ys = true if true_full is None else true_full

    # mix_rbf_mmd2 needs equally sized samples; randomly subsample the
    # larger of the two sets down to the smaller one.
    n_xs, n_ys = xs.shape[0], ys.shape[0]
    if n_xs > n_ys:
        xs = xs[torch.randperm(n_xs)[:n_ys]]
    elif n_ys > n_xs:
        ys = ys[torch.randperm(n_ys)[:n_xs]]

    results["MMD"] = mix_rbf_mmd2(xs, ys, sigma_list=[0.01, 0.1, 1, 10, 100]).item()
    return results
70
+
71
+
72
def compute_tmv_from_mass_over_time(mass_over_time, all_endpoints, time_points=None, timepoint_data=None, time_index=None, target_time=None, gt_key_template='t1_{}', weights_over_time=None):
    """Total mass variation (TMV) between predicted and ground-truth branch masses.

    Predicted per-branch masses come from, in order of preference:
      1. the sum of per-particle weights at ``time_index`` (``weights_over_time``),
      2. mean weight at ``time_index`` (``mass_over_time``) times particle count,
      3. particle count alone (weight assumed 1).
    Ground-truth masses are cell counts read from ``timepoint_data`` using
    ``gt_key_template`` (falling back to the key's prefix before '_', then 0).
    Both are normalized by the t0 cell count (or the GT/pred totals as a
    fallback) and TMV = 0.5 * sum(|pred_frac - gt_frac|) is returned along
    with the intermediate quantities.

    NOTE(review): the broad ``except Exception`` fallbacks below are load-
    bearing — they decide which mass estimate is used — so the exact order
    of these branches must be preserved.
    """

    if weights_over_time is not None or mass_over_time is not None:
        if time_index is None:
            if target_time is not None and time_points is not None:
                # Snap to the time point closest to the requested target_time.
                arr = np.array(time_points)
                time_index = int(np.argmin(np.abs(arr - float(target_time))))
            else:
                # default to last index
                ref_list = weights_over_time if weights_over_time is not None else mass_over_time
                time_index = len(ref_list[0]) - 1
    else:
        # neither available; time_index not used
        if time_index is None:
            time_index = -1

    n_branches = len(all_endpoints)

    # initial total cells for normalization
    n_initial = None
    if timepoint_data is not None and 't0' in timepoint_data:
        try:
            n_initial = int(timepoint_data['t0'].shape[0])
        except Exception:
            n_initial = None

    pred_masses = []
    for i in range(n_branches):
        # Use sum of actual particle weights if available, otherwise mean_weight * num_particles
        if weights_over_time is not None:
            try:
                weights_tensor = weights_over_time[i][time_index]
                # Sum all particle weights to get total mass for this branch
                total_mass = float(weights_tensor.sum().item())
                pred_masses.append(total_mass)
                continue
            except Exception:
                pass  # Fall through to mean weight calculation

        # Fallback: mean weight from mass_over_time if available, otherwise assume weight=1
        mean_w = 1.0
        if mass_over_time is not None:
            try:
                mean_w = float(mass_over_time[i][time_index])
            except Exception:
                mean_w = 1.0

        # determine number of particles for this branch
        num_particles = 0
        try:
            if hasattr(all_endpoints[i], 'shape'):
                num_particles = int(all_endpoints[i].shape[0])
            else:
                num_particles = int(len(all_endpoints[i]))
        except Exception:
            num_particles = 0

        pred_masses.append(mean_w * float(num_particles))

    # ground-truth masses per branch
    gt_masses = []
    if timepoint_data is not None:
        for i in range(n_branches):
            key1 = gt_key_template.format(i)
            if key1 in timepoint_data:
                gt_masses.append(float(timepoint_data[key1].shape[0]))
            else:
                # Fall back to the un-indexed base key (e.g. 't1' for 't1_{}').
                base_key = gt_key_template.split("_")[0] if '_' in gt_key_template else gt_key_template
                if base_key in timepoint_data:
                    gt_masses.append(float(timepoint_data[base_key].shape[0]))
                else:
                    gt_masses.append(0.0)
    else:
        gt_masses = [0.0 for _ in range(n_branches)]

    # determine normalization denominator
    if n_initial is None:
        s = float(sum(gt_masses))
        if s > 0:
            n_initial = s
        else:
            # Last resort: normalize by predicted total (or 1 to avoid /0).
            n_initial = float(sum(pred_masses)) if sum(pred_masses) > 0 else 1.0

    pred_fracs = [m / float(n_initial) for m in pred_masses]
    gt_fracs = [m / float(n_initial) for m in gt_masses]

    # TMV = half the L1 distance between the two mass-fraction vectors.
    tmv = 0.5 * float(np.sum(np.abs(np.array(pred_fracs) - np.array(gt_fracs))))

    return {
        'time_index': time_index,
        'pred_masses': pred_masses,
        'gt_masses': gt_masses,
        'pred_fracs': pred_fracs,
        'gt_fracs': gt_fracs,
        'tmv': tmv,
    }
168
+
169
+
170
class FlowNetTestLidar(GrowthNetTrain):
    """Test-time evaluation and 3D plotting for the LiDAR BranchSBM experiment.

    Overrides ``test_step`` only; training behavior comes from GrowthNetTrain.
    """

    def test_step(self, batch, batch_idx):
        """Evaluate branch endpoints against ground truth and save figures.

        Unwraps the (several possible) CombinedLoader batch layouts, rolls
        out trajectories from all validation x0 points, computes per-branch
        and combined W1/W2/MMD over 5 random-subsample trials, writes
        metrics to JSON/CSV under results/<run_name>/, and renders 3D LiDAR
        trajectory figures.
        """
        # Unwrap CombinedLoader outer tuple if needed
        if isinstance(batch, (list, tuple)) and len(batch) == 1:
            batch = batch[0]

        if isinstance(batch, dict) and "test_samples" in batch:
            test_samples = batch["test_samples"]
            metric_samples = batch["metric_samples"]

            # Some loader layouts append a trailing int (e.g. a batch index);
            # strip it and keep only the sample payload.
            if isinstance(test_samples, (list, tuple)) and len(test_samples) >= 2 and isinstance(test_samples[-1], int):
                test_samples = test_samples[0]
            if isinstance(metric_samples, (list, tuple)) and len(metric_samples) >= 2 and isinstance(metric_samples[-1], int):
                metric_samples = metric_samples[0]

            if isinstance(test_samples, (list, tuple)) and len(test_samples) == 1:
                test_samples = test_samples[0]
            main_batch = test_samples

            # Normalize metric_samples into a flat list of batches.
            if isinstance(metric_samples, dict):
                metric_batch = list(metric_samples.values())
            elif isinstance(metric_samples, (list, tuple)):
                metric_batch = [m[0] if isinstance(m, (list, tuple)) and len(m) == 1 else m for m in metric_samples]
            else:
                metric_batch = [metric_samples]
        elif isinstance(batch, (list, tuple)) and len(batch) == 2:
            # Old tuple format: (test_samples, metric_samples)
            # Each could be dict or list
            test_samples = batch[0]
            metric_samples = batch[1]

            if isinstance(test_samples, dict):
                main_batch = test_samples
            elif isinstance(test_samples, (list, tuple)):
                main_batch = test_samples[0]
            else:
                main_batch = test_samples

            if isinstance(metric_samples, dict):
                metric_batch = list(metric_samples.values())
            elif isinstance(metric_samples, (list, tuple)):
                metric_batch = [m[0] if isinstance(m, (list, tuple)) and len(m) == 1 else m for m in metric_samples]
            else:
                metric_batch = [metric_samples]
        else:
            # Fallback
            main_batch = batch
            metric_batch = []

        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        # main_batch is a dict like {"x0": (tensor, weights), ...}
        if isinstance(main_batch, dict):
            device = main_batch["x0"][0].device
        else:
            device = main_batch[0]["x0"][0].device

        # Use ALL validation x0 points (unit weights) as initial conditions,
        # not just the current mini-batch.
        x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device)
        full_batch = {"x0": (x0_all, w0_all)}

        time_points, all_endpoints, all_trajs, mass_over_time, energy_over_time, weights_over_time = self.get_mass_and_position(full_batch, metric_batch)

        cloud_points = main_batch["dataset"][0]  # [N, 3]

        # Run 5 trials with random subsampling for robust metrics
        n_trials = 5

        # Compute per-branch metrics
        metrics_dict = {}
        for i, endpoints in enumerate(all_endpoints):
            # Per-branch GT key 't1_<i+1>' when present, else shared 't1'.
            true_data_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            true_data = torch.tensor(timepoint_data[true_data_key], dtype=torch.float32).to(endpoints.device)

            w1_br, w2_br, mmd_br = [], [], []
            for trial in range(n_trials):
                # Subsample both sets to the same size for each trial.
                n_min = min(endpoints.shape[0], true_data.shape[0])
                perm_pred = torch.randperm(endpoints.shape[0])[:n_min]
                perm_gt = torch.randperm(true_data.shape[0])[:n_min]
                # W1/W2 on the first 2 dims; MMD on the full dimensionality.
                m = compute_distribution_distances(
                    endpoints[perm_pred, :2], true_data[perm_gt, :2],
                    pred_full=endpoints[perm_pred], true_full=true_data[perm_gt]
                )
                w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])

            metrics_dict[f"branch_{i+1}"] = {
                "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
                "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
                "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
            }
            self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True)
            print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
                  f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
                  f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")

        # Compute combined metrics across all branches (5 trials)
        all_pred_combined = torch.cat(list(all_endpoints), dim=0)
        all_true_list = []
        for i in range(len(all_endpoints)):
            true_data_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            all_true_list.append(torch.tensor(timepoint_data[true_data_key], dtype=torch.float32).to(all_pred_combined.device))
        all_true_combined = torch.cat(all_true_list, dim=0)

        w1_trials, w2_trials, mmd_trials = [], [], []
        for trial in range(n_trials):
            n_min = min(all_pred_combined.shape[0], all_true_combined.shape[0])
            perm_pred = torch.randperm(all_pred_combined.shape[0])[:n_min]
            perm_gt = torch.randperm(all_true_combined.shape[0])[:n_min]
            m = compute_distribution_distances(
                all_pred_combined[perm_pred, :2], all_true_combined[perm_gt, :2],
                pred_full=all_pred_combined[perm_pred], true_full=all_true_combined[perm_gt]
            )
            w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])

        w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
        w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
        mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
        self.log("test/W1_combined", w1_mean, on_epoch=True)
        self.log("test/W2_combined", w2_mean, on_epoch=True)
        self.log("test/MMD_combined", mmd_mean, on_epoch=True)

        metrics_dict["combined"] = {
            "W1_mean": float(w1_mean), "W1_std": float(w1_std),
            "W2_mean": float(w2_mean), "W2_std": float(w2_std),
            "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
            "n_trials": n_trials,
        }
        print(f"\n=== Combined ===")
        print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
        print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
        print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")

        # Inverse-transform cloud points for visualization
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # Create results directory structure
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        # Save metrics to JSON
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        # Save detailed per-branch metrics to CSV
        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict.keys()):
                m = metrics_dict[key]
                writer.writerow([key,
                    f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
                    f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
                    f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        # Convert all_trajs from list of lists to stacked tensors for plotting
        # all_trajs[i] is a list of T tensors of shape [B, D]
        # Stack to get shape [B, T, D]
        stacked_trajs = []
        for traj_list in all_trajs:
            # Stack along time dimension (dim=1) to get [B, T, D]
            stacked_traj = torch.stack(traj_list, dim=1)
            stacked_trajs.append(stacked_traj)

        # Inverse-transform trajectories to match cloud_points coordinates
        if self.whiten:
            stacked_trajs_original = []
            for traj in stacked_trajs:
                B, T, D = traj.shape
                # Reshape to [B*T, D] for inverse transform
                traj_flat = traj.reshape(-1, D).cpu().detach().numpy()
                traj_inv = self.trainer.datamodule.scaler.inverse_transform(traj_flat)
                # Reshape back to [B, T, D]
                traj_inv = torch.tensor(traj_inv).reshape(B, T, D)
                stacked_trajs_original.append(traj_inv)
            stacked_trajs = stacked_trajs_original

        # ===== Plot all branches together =====
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(stacked_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        plt.savefig(f'{figures_dir}/{self.args.data_name}_all_branches.png', dpi=300)
        plt.close()

        # ===== Plot each branch separately =====
        for i, traj in enumerate(stacked_trajs):
            fig = plt.figure(figsize=(10, 8))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            plt.savefig(f'{figures_dir}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
            plt.close()

        print(f"LiDAR figures saved to {figures_dir}")
376
+
377
+
378
+ class FlowNetTestMouse(GrowthNetTrain):
379
+
380
+ def test_step(self, batch, batch_idx):
381
+ # Handle both tuple and dict batch formats from CombinedLoader
382
+ if isinstance(batch, dict):
383
+ main_batch = batch.get("test_samples", batch)
384
+ if isinstance(main_batch, tuple):
385
+ main_batch = main_batch[0]
386
+ elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
387
+ if isinstance(batch[0], dict):
388
+ main_batch = batch[0].get("test_samples", batch[0])
389
+ if isinstance(main_batch, tuple):
390
+ main_batch = main_batch[0]
391
+ else:
392
+ main_batch = batch[0][0]
393
+ else:
394
+ main_batch = batch
395
+
396
+ device = main_batch["x0"][0].device
397
+
398
+ # Use val x0 as initial conditions
399
+ x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
400
+
401
+ # Get timepoint data for ground truth
402
+ timepoint_data = self.trainer.datamodule.get_timepoint_data()
403
+
404
+ # Ground truth at t1 (intermediate timepoint)
405
+ data_t1 = torch.tensor(timepoint_data['t1'], dtype=torch.float32)
406
+
407
+ # Define color schemes for mouse (2 branches)
408
+ custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
409
+ custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
410
+ custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
411
+ custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)
412
+
413
+ t_span_full = torch.linspace(0, 1.0, 100).to(device)
414
+ all_trajs = []
415
+
416
+ for i, flow_net in enumerate(self.flow_nets):
417
+ node = NeuralODE(
418
+ flow_model_torch_wrapper(flow_net),
419
+ solver="euler",
420
+ sensitivity="adjoint",
421
+ ).to(device)
422
+
423
+ with torch.no_grad():
424
+ traj = node.trajectory(x0, t_span_full).cpu() # [T, B, D]
425
+
426
+ traj = torch.transpose(traj, 0, 1) # [B, T, D]
427
+ all_trajs.append(traj)
428
+
429
+ t_span_metric_t1 = torch.linspace(0, 0.5, 50).to(device)
430
+ t_span_metric_t2 = torch.linspace(0, 1.0, 100).to(device)
431
+ n_trials = 5
432
+
433
+ # Gather t2 branch ground truth
434
+ data_t2_branches = []
435
+ for i in range(len(self.flow_nets)):
436
+ key = f't2_{i+1}'
437
+ if key in timepoint_data:
438
+ data_t2_branches.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
439
+ elif i == 0 and 't2' in timepoint_data:
440
+ data_t2_branches.append(torch.tensor(timepoint_data['t2'], dtype=torch.float32))
441
+ else:
442
+ data_t2_branches.append(None)
443
+
444
+ # Combined t2 ground truth (all branches merged)
445
+ data_t2_all_list = [d for d in data_t2_branches if d is not None]
446
+ data_t2_combined = torch.cat(data_t2_all_list, dim=0) if data_t2_all_list else None
447
+
448
+ # ---- t1 combined metrics (all branches pooled, compared to t1) ----
449
+ w1_t1_trials, w2_t1_trials, mmd_t1_trials = [], [], []
450
+
451
+ for trial in range(n_trials):
452
+ all_preds = []
453
+ for i, flow_net in enumerate(self.flow_nets):
454
+ node = NeuralODE(
455
+ flow_model_torch_wrapper(flow_net),
456
+ solver="euler",
457
+ sensitivity="adjoint",
458
+ ).to(device)
459
+
460
+ with torch.no_grad():
461
+ traj = node.trajectory(x0, t_span_metric_t1) # [T, B, D]
462
+
463
+ x_final = traj[-1].cpu() # [B, D]
464
+ all_preds.append(x_final)
465
+
466
+ preds = torch.cat(all_preds, dim=0)
467
+ target_size = preds.shape[0]
468
+ perm = torch.randperm(data_t1.shape[0])[:target_size]
469
+ data_t1_reduced = data_t1[perm]
470
+
471
+ metrics = compute_distribution_distances(
472
+ preds[:, :2], data_t1_reduced[:, :2]
473
+ )
474
+ w1_t1_trials.append(metrics["W1"])
475
+ w2_t1_trials.append(metrics["W2"])
476
+ mmd_t1_trials.append(metrics["MMD"])
477
+
478
+ # ---- t2 per-branch metrics (each branch endpoint vs its own t2 cluster) ----
479
+ branch_t2_metrics = {}
480
+ for i, flow_net in enumerate(self.flow_nets):
481
+ if data_t2_branches[i] is None:
482
+ continue
483
+ w1_br, w2_br, mmd_br = [], [], []
484
+ for trial in range(n_trials):
485
+ node = NeuralODE(
486
+ flow_model_torch_wrapper(flow_net),
487
+ solver="euler",
488
+ sensitivity="adjoint",
489
+ ).to(device)
490
+ with torch.no_grad():
491
+ traj = node.trajectory(x0, t_span_metric_t2)
492
+ x_final = traj[-1].cpu()
493
+ gt = data_t2_branches[i]
494
+ n_min = min(x_final.shape[0], gt.shape[0])
495
+ perm_pred = torch.randperm(x_final.shape[0])[:n_min]
496
+ perm_gt = torch.randperm(gt.shape[0])[:n_min]
497
+ m = compute_distribution_distances(
498
+ x_final[perm_pred, :2], gt[perm_gt, :2]
499
+ )
500
+ w1_br.append(m["W1"])
501
+ w2_br.append(m["W2"])
502
+ mmd_br.append(m["MMD"])
503
+ branch_t2_metrics[f"branch_{i+1}_t2"] = {
504
+ "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
505
+ "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
506
+ "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
507
+ }
508
+ print(f"Branch {i+1} @ t2 — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
509
+ f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
510
+ f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")
511
+
512
+ # ---- t2 combined metrics (all branches pooled, compared to all t2) ----
513
+ w1_t2_trials, w2_t2_trials, mmd_t2_trials = [], [], []
514
+ if data_t2_combined is not None:
515
+ for trial in range(n_trials):
516
+ all_preds = []
517
+ for i, flow_net in enumerate(self.flow_nets):
518
+ node = NeuralODE(
519
+ flow_model_torch_wrapper(flow_net),
520
+ solver="euler",
521
+ sensitivity="adjoint",
522
+ ).to(device)
523
+ with torch.no_grad():
524
+ traj = node.trajectory(x0, t_span_metric_t2)
525
+ all_preds.append(traj[-1].cpu())
526
+ preds = torch.cat(all_preds, dim=0)
527
+ n_min = min(preds.shape[0], data_t2_combined.shape[0])
528
+ perm_pred = torch.randperm(preds.shape[0])[:n_min]
529
+ perm_gt = torch.randperm(data_t2_combined.shape[0])[:n_min]
530
+ m = compute_distribution_distances(
531
+ preds[perm_pred, :2], data_t2_combined[perm_gt, :2]
532
+ )
533
+ w1_t2_trials.append(m["W1"])
534
+ w2_t2_trials.append(m["W2"])
535
+ mmd_t2_trials.append(m["MMD"])
536
+
537
+ # Compute mean and std
538
+ w1_t1_mean, w1_t1_std = np.mean(w1_t1_trials), np.std(w1_t1_trials, ddof=1)
539
+ w2_t1_mean, w2_t1_std = np.mean(w2_t1_trials), np.std(w2_t1_trials, ddof=1)
540
+ mmd_t1_mean, mmd_t1_std = np.mean(mmd_t1_trials), np.std(mmd_t1_trials, ddof=1)
541
+
542
+ # Log metrics
543
+ self.log("test/W1_combined_t1", w1_t1_mean, on_epoch=True)
544
+ self.log("test/W2_combined_t1", w2_t1_mean, on_epoch=True)
545
+ self.log("test/MMD_combined_t1", mmd_t1_mean, on_epoch=True)
546
+
547
+ metrics_dict = {
548
+ "combined_t1": {
549
+ "W1_mean": float(w1_t1_mean), "W1_std": float(w1_t1_std),
550
+ "W2_mean": float(w2_t1_mean), "W2_std": float(w2_t1_std),
551
+ "MMD_mean": float(mmd_t1_mean), "MMD_std": float(mmd_t1_std),
552
+ "n_trials": n_trials,
553
+ }
554
+ }
555
+ metrics_dict.update(branch_t2_metrics)
556
+
557
+ if w1_t2_trials:
558
+ w1_t2_mean, w1_t2_std = np.mean(w1_t2_trials), np.std(w1_t2_trials, ddof=1)
559
+ w2_t2_mean, w2_t2_std = np.mean(w2_t2_trials), np.std(w2_t2_trials, ddof=1)
560
+ mmd_t2_mean, mmd_t2_std = np.mean(mmd_t2_trials), np.std(mmd_t2_trials, ddof=1)
561
+ self.log("test/W1_combined_t2", w1_t2_mean, on_epoch=True)
562
+ self.log("test/W2_combined_t2", w2_t2_mean, on_epoch=True)
563
+ self.log("test/MMD_combined_t2", mmd_t2_mean, on_epoch=True)
564
+ metrics_dict["combined_t2"] = {
565
+ "W1_mean": float(w1_t2_mean), "W1_std": float(w1_t2_std),
566
+ "W2_mean": float(w2_t2_mean), "W2_std": float(w2_t2_std),
567
+ "MMD_mean": float(mmd_t2_mean), "MMD_std": float(mmd_t2_std),
568
+ "n_trials": n_trials,
569
+ }
570
+
571
+ print(f"\n=== Combined @ t1 ===")
572
+ print(f"W1: {w1_t1_mean:.6f} ± {w1_t1_std:.6f}")
573
+ print(f"W2: {w2_t1_mean:.6f} ± {w2_t1_std:.6f}")
574
+ print(f"MMD: {mmd_t1_mean:.6f} ± {mmd_t1_std:.6f}")
575
+ if w1_t2_trials:
576
+ print(f"\n=== Combined @ t2 ===")
577
+ print(f"W1: {w1_t2_mean:.6f} ± {w1_t2_std:.6f}")
578
+ print(f"W2: {w2_t2_mean:.6f} ± {w2_t2_std:.6f}")
579
+ print(f"MMD: {mmd_t2_mean:.6f} ± {mmd_t2_std:.6f}")
580
+
581
+ # Create results directory structure
582
+ run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
583
+ results_dir = os.path.join(self.args.working_dir, 'results', run_name)
584
+ figures_dir = f'{results_dir}/figures'
585
+ os.makedirs(figures_dir, exist_ok=True)
586
+
587
+ # Save metrics to JSON
588
+ metrics_path = f'{results_dir}/metrics.json'
589
+ with open(metrics_path, 'w') as f:
590
+ json.dump(metrics_dict, f, indent=2)
591
+ print(f"Metrics saved to {metrics_path}")
592
+
593
+
594
+ # Save detailed metrics to CSV
595
+ detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
596
+ with open(detailed_csv_path, 'w', newline='') as csvfile:
597
+ writer = csv.writer(csvfile)
598
+ writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
599
+ for key in sorted(metrics_dict.keys()):
600
+ m = metrics_dict[key]
601
+ writer.writerow([key,
602
+ f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}',
603
+ f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}',
604
+ f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}'])
605
+ print(f"Detailed metrics CSV saved to {detailed_csv_path}")
606
+
607
+ # ===== Plot individual branches (using full t_span trajectories) =====
608
+ self._plot_mouse_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
609
+
610
+ # ===== Plot all branches together =====
611
+ self._plot_mouse_combined(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
612
+
613
+ print(f"Mouse figures saved to {figures_dir}")
614
+
615
def _plot_mouse_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
    """Plot each branch's trajectories separately over a timepoint background.

    Writes one PNG per branch to ``save_dir``.

    Args:
        all_trajs: per-branch trajectories; each entry is either a tensor of
            shape [B, T, D] or a list of T tensors of shape [B, D].
        timepoint_data: dict mapping timepoint keys ('t0', 't1', and 't2' or
            't2_1'/'t2_2') to coordinate arrays or tensors. The caller's dict
            is no longer mutated; tensors are converted to numpy locally.
        save_dir: directory where the figures are saved.
        cmap1, cmap2: colormaps for the branch-1 / branch-2 time gradients.
    """
    n_branches = len(all_trajs)
    branch_names = [f'Branch {i+1}' for i in range(n_branches)]
    # Cycle the 2-entry palette instead of slicing it, so a run with more
    # than two branches no longer raises IndexError (identical for <= 2).
    base_colors = ['#B83CFF', '#50B2D7']
    base_cmaps = [cmap1, cmap2]
    branch_colors = [base_colors[i % len(base_colors)] for i in range(n_branches)]
    cmaps = [base_cmaps[i % len(base_cmaps)] for i in range(n_branches)]

    # Stack list-of-tensors into [B, T, D] numpy arrays
    all_trajs_np = []
    for traj in all_trajs:
        if isinstance(traj, list):
            traj = torch.stack(traj, dim=1)  # list of [B,D] -> [B,T,D]
        all_trajs_np.append(traj.cpu().detach().numpy())
    all_trajs = all_trajs_np

    # Convert timepoint data to numpy into a LOCAL dict — previously this
    # mutated the caller's dict in place, a hidden side effect.
    tp_data = {}
    for key, val in timepoint_data.items():
        tp_data[key] = val.cpu().numpy() if torch.is_tensor(val) else val

    # Compute global axis limits across all timepoints and trajectories
    all_coords = []
    for key in ['t0', 't1', 't2', 't2_1', 't2_2']:
        if key in tp_data:
            all_coords.append(tp_data[key][:, :2])
    for traj_np in all_trajs:
        all_coords.append(traj_np.reshape(-1, traj_np.shape[-1])[:, :2])

    all_coords = np.concatenate(all_coords, axis=0)
    x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
    y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()

    # Add a 5% margin around the data extent
    x_margin = 0.05 * (x_max - x_min)
    y_margin = 0.05 * (y_max - y_min)
    x_min -= x_margin
    x_max += x_margin
    y_min -= y_margin
    y_max += y_margin

    for i, traj in enumerate(all_trajs):
        fig, ax = plt.subplots(figsize=(10, 8))
        cmap = cmaps[i]
        c_end = branch_colors[i]

        # Timepoint background: use the branch-specific t2 key when the
        # data is branched, else fall back to the shared 't2'.
        t2_key = f't2_{i+1}' if f't2_{i+1}' in tp_data else 't2'
        coords_list = [tp_data['t0'], tp_data['t1'], tp_data[t2_key]]
        tp_colors = ['#05009E', '#A19EFF', c_end]
        tp_labels = ["t=0", "t=1", f"t=2 (branch {i+1})"]

        for coords, color, label in zip(coords_list, tp_colors, tp_labels):
            alpha = 0.8 if color == '#05009E' else 0.6
            ax.scatter(coords[:, 0], coords[:, 1],
                       c=color, s=80, alpha=alpha, marker='x',
                       label=f'{label} cells', linewidth=1.5)

        # Plot continuous trajectories with LineCollection for speed
        traj_2d = traj[:, :, :2]
        n_time = traj_2d.shape[1]
        color_vals = cmap(np.linspace(0, 1, n_time))
        segments = []
        seg_colors = []
        for j in range(traj_2d.shape[0]):
            pts = traj_2d[j]  # [T, 2]
            segs = np.stack([pts[:-1], pts[1:]], axis=1)
            segments.append(segs)
            seg_colors.append(color_vals[:-1])
        segments = np.concatenate(segments, axis=0)
        seg_colors = np.concatenate(seg_colors, axis=0)
        lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
        ax.add_collection(lc)

        # Start and end points
        ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                   c='#05009E', s=30, marker='o', label='Trajectory Start',
                   zorder=5, edgecolors='white', linewidth=1)
        ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                   c=c_end, s=30, marker='o', label='Trajectory End',
                   zorder=5, edgecolors='white', linewidth=1)

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xlabel("PC1", fontsize=12)
        ax.set_ylabel("PC2", fontsize=12)
        ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=12, frameon=False)

        plt.tight_layout()
        plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
        plt.close()
707
+
708
def _plot_mouse_combined(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
    """Render every branch's trajectories on one shared axis.

    Draws the timepoint scatter background first, then each branch's
    trajectories as a time-gradient LineCollection with start/end markers,
    and saves '<data_name>_combined.png' into ``save_dir``.
    """
    n_branches = len(all_trajs)
    branch_names = [f'Branch {k+1}' for k in range(n_branches)]
    branch_colors = ['#B83CFF', '#50B2D7'][:n_branches]

    # Branched data exposes per-branch t2 keys; otherwise a single 't2'.
    if 't2_1' in timepoint_data:
        tp_keys = ['t0', 't1', 't2_1', 't2_2']
        tp_colors = ['#05009E', '#A19EFF', '#B83CFF', '#50B2D7']
        tp_labels = ['t=0', 't=1', 't=2 (branch 1)', 't=2 (branch 2)']
    else:
        tp_keys = ['t0', 't1', 't2']
        tp_colors = ['#05009E', '#A19EFF', '#B83CFF']
        tp_labels = ['t=0', 't=1', 't=2']

    # Normalise each trajectory to a [B, T, D] numpy array.
    def _as_numpy(traj):
        if isinstance(traj, list):
            traj = torch.stack(traj, dim=1)
        if torch.is_tensor(traj):
            traj = traj.cpu().detach().numpy()
        return traj

    all_trajs = [_as_numpy(t) for t in all_trajs]

    # Convert any tensor-valued timepoint entries to numpy.
    for key in list(timepoint_data.keys()):
        if torch.is_tensor(timepoint_data[key]):
            timepoint_data[key] = timepoint_data[key].cpu().numpy()

    fig, ax = plt.subplots(figsize=(12, 10))

    # Background scatter for each timepoint that is actually present.
    for t_key, color, label in zip(tp_keys, tp_colors, tp_labels):
        if t_key not in timepoint_data:
            continue
        pts = timepoint_data[t_key]
        ax.scatter(pts[:, 0], pts[:, 1],
                   c=color, s=80, alpha=0.4, marker='x',
                   label=f'{label} cells', linewidth=1.5)

    # One time-gradient LineCollection per branch, plus endpoint markers.
    cmaps = [cmap1, cmap2]
    for k, traj in enumerate(all_trajs):
        xy = traj[:, :, :2]
        steps = xy.shape[1]
        gradient = cmaps[k](np.linspace(0, 1, steps))
        seg_chunks = [np.stack([row[:-1], row[1:]], axis=1) for row in xy]
        color_chunks = [gradient[:-1] for _ in range(xy.shape[0])]
        lc = LineCollection(np.concatenate(seg_chunks, axis=0),
                            colors=np.concatenate(color_chunks, axis=0),
                            linewidths=2, alpha=0.8)
        ax.add_collection(lc)

        ax.scatter(xy[:, 0, 0], xy[:, 0, 1],
                   c='#05009E', s=30, marker='o',
                   label=f'{branch_names[k]} Start',
                   zorder=5, edgecolors='white', linewidth=1)
        ax.scatter(xy[:, -1, 0], xy[:, -1, 1],
                   c=branch_colors[k], s=30, marker='o',
                   label=f'{branch_names[k]} End',
                   zorder=5, edgecolors='white', linewidth=1)

    ax.set_xlabel("PC1", fontsize=14)
    ax.set_ylabel("PC2", fontsize=14)
    ax.set_title("All Branch Trajectories with Timepoint Background",
                 fontsize=16, weight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='upper right', fontsize=12, frameon=False)

    plt.tight_layout()
    plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
    plt.close()
792
+
793
+
794
class FlowNetTestClonidine(BranchFlowNetTrainBase):
    """Test class for Clonidine perturbation experiment (1 or 2 branches).

    Integrates each trained branch flow net forward from the validation x0
    cells, computes W1/W2/MMD distances between simulated endpoints and the
    t=1 ground truth (per branch and pooled), writes metrics to JSON/CSV
    under ``<working_dir>/results/<run_name>``, and renders per-branch and
    combined trajectory figures.
    """

    def test_step(self, batch, batch_idx):
        """Run all branches, compute subsampled metrics, save results and figures.

        Args:
            batch: test batch from the datamodule's CombinedLoader; may be a
                dict containing "test_samples" or a list/tuple wrapping one.
            batch_idx: Lightning batch index (required by the API; only the
                batch itself is used here, and only to read the device).
        """
        # Handle both dict and tuple batch formats from CombinedLoader
        if isinstance(batch, dict) and "test_samples" in batch:
            # New format: {"test_samples": {...}, "metric_samples": {...}}
            main_batch = batch["test_samples"]
        elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
            # Old format with nested structure
            test_samples = batch[0]
            if isinstance(test_samples, dict) and "test_samples" in test_samples:
                main_batch = test_samples["test_samples"][0]
            else:
                main_batch = test_samples
        else:
            # Fallback
            main_batch = batch

        # Get timepoint data (full per-timepoint datasets from the datamodule)
        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        # The batch is only consulted to find which device x0 should live on.
        device = main_batch["x0"][0].device

        # Use val x0 as initial conditions
        x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        t_span = torch.linspace(0, 1, 100).to(device)

        # Define color schemes for clonidine (2 branches)
        custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
        custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)

        all_trajs = []      # per-branch trajectories, each [B, T, D] on CPU
        all_endpoints = []  # per-branch final states, each [B, D]

        # Integrate each branch's vector field with a Euler neural ODE.
        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)
            all_endpoints.append(traj[:, -1, :])

        # Run 5 trials with random subsampling for robust metrics
        n_trials = 5
        n_branches = len(self.flow_nets)

        # Gather per-branch ground truth: branch-specific 't1_i' keys when
        # available, otherwise the shared 't1' key.
        gt_data_per_branch = []
        for i in range(n_branches):
            if n_branches == 1:
                key = 't1'
            else:
                key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
        gt_all = torch.cat(gt_data_per_branch, dim=0)

        # Per-branch metrics (5 trials). NOTE: distances use only the first
        # two coordinates (the plotted 2D space), not the full state.
        metrics_dict = {}
        for i in range(n_branches):
            w1_br, w2_br, mmd_br = [], [], []
            pred = all_endpoints[i]
            gt = gt_data_per_branch[i]
            for trial in range(n_trials):
                # Subsample both sets to a common size before comparing.
                n_min = min(pred.shape[0], gt.shape[0])
                perm_pred = torch.randperm(pred.shape[0])[:n_min]
                perm_gt = torch.randperm(gt.shape[0])[:n_min]
                m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
                w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
            # Sample std (ddof=1) across the 5 subsampling trials.
            metrics_dict[f"branch_{i+1}"] = {
                "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
                "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
                "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
            }
            self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True)
            print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
                  f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
                  f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")

        # Combined metrics (5 trials) over all branch endpoints pooled.
        pred_all = torch.cat(all_endpoints, dim=0)
        w1_trials, w2_trials, mmd_trials = [], [], []
        for trial in range(n_trials):
            n_min = min(pred_all.shape[0], gt_all.shape[0])
            perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
            perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
            m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
            w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])

        w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
        w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
        mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
        self.log("test/W1_t1_combined", w1_mean, on_epoch=True)
        self.log("test/W2_t1_combined", w2_mean, on_epoch=True)
        self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True)
        metrics_dict['t1_combined'] = {
            "W1_mean": float(w1_mean), "W1_std": float(w1_std),
            "W2_mean": float(w2_mean), "W2_std": float(w2_std),
            "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
            "n_trials": n_trials,
        }
        print(f"\n=== Combined @ t1 ===")
        print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
        print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
        print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")

        # Create results directory structure
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        # Save metrics to JSON
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        # Save detailed metrics to CSV
        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict.keys()):
                m = metrics_dict[key]
                # Fall back to legacy single-value keys ("W1", ...) when the
                # *_mean keys are absent.
                writer.writerow([key,
                                 f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        # ===== Plot branches =====
        self._plot_clonidine_branches(all_trajs, timepoint_data, figures_dir, custom_cmap_1, custom_cmap_2)
        self._plot_clonidine_combined(all_trajs, timepoint_data, figures_dir)

        print(f"Clonidine figures saved to {figures_dir}")

    def _plot_clonidine_branches(self, all_trajs, timepoint_data, save_dir, cmap1, cmap2):
        """Plot each branch separately: t0/t1 background plus trajectories.

        Args:
            all_trajs: list of [B, T, D] trajectories, one per branch.
            timepoint_data: dict of coordinate arrays keyed 't0' and 't1' (or
                branch-specific 't1_1'/'t1_2').
            save_dir: output directory for the per-branch PNGs.
            cmap1, cmap2: per-branch trajectory colormaps.
        """
        branch_names = ['Branch 1', 'Branch 2']
        branch_colors = ['#B83CFF', '#50B2D7']
        cmaps = [cmap1, cmap2]

        # Compute global axis limits – handle single vs multi branch keys
        all_coords = []
        if 't1_1' in timepoint_data:
            tp_keys = ['t0'] + [f't1_{i+1}' for i in range(len(all_trajs))]
        else:
            tp_keys = ['t0', 't1']
        for key in tp_keys:
            all_coords.append(timepoint_data[key][:, :2])
        for traj in all_trajs:
            all_coords.append(traj.reshape(-1, traj.shape[-1])[:, :2])

        all_coords = np.concatenate(all_coords, axis=0)
        x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
        y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()

        # Add a 5% margin around the data extent
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min -= x_margin
        x_max += x_margin
        y_min -= y_margin
        y_max += y_margin

        for i, traj in enumerate(all_trajs):
            fig, ax = plt.subplots(figsize=(10, 8))
            c_end = branch_colors[i]

            # Plot timepoint background
            t1_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            coords_list = [timepoint_data['t0'], timepoint_data[t1_key]]
            tp_colors = ['#05009E', c_end]
            t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1"
            tp_labels = ["t=0", t1_label]

            for coords, color, label in zip(coords_list, tp_colors, tp_labels):
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=0.4, marker='x',
                           label=f'{label} cells', linewidth=1.5)

            # Plot continuous trajectories with LineCollection for speed
            traj_2d = traj[:, :, :2]
            n_time = traj_2d.shape[1]
            color_vals = cmaps[i](np.linspace(0, 1, n_time))
            segments = []
            seg_colors = []
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            # Start and end points
            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o', label='Trajectory Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o', label='Trajectory End',
                       zorder=5, edgecolors='white', linewidth=1)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xlabel("PC1", fontsize=12)
            ax.set_ylabel("PC2", fontsize=12)
            ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=16, frameon=False)

            plt.tight_layout()
            plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
            plt.close()

    def _plot_clonidine_combined(self, all_trajs, timepoint_data, save_dir):
        """Plot all branches together on one axis and save a combined PNG."""
        branch_names = ['Branch 1', 'Branch 2']
        branch_colors = ['#B83CFF', '#50B2D7']

        fig, ax = plt.subplots(figsize=(12, 10))

        # Build timepoint keys/colors/labels depending on single vs multi branch
        if 't1_1' in timepoint_data:
            tp_keys = ['t0'] + [f't1_{j+1}' for j in range(len(all_trajs))]
            tp_labels_list = ['t=0'] + [f't=1 (branch {j+1})' for j in range(len(all_trajs))]
        else:
            tp_keys = ['t0', 't1']
            tp_labels_list = ['t=0', 't=1']
        tp_colors = ['#05009E', '#B83CFF', '#50B2D7'][:len(tp_keys)]

        # Plot timepoint background
        for t_key, color, label in zip(tp_keys, tp_colors, tp_labels_list):
            coords = timepoint_data[t_key]
            ax.scatter(coords[:, 0], coords[:, 1],
                       c=color, s=80, alpha=0.4, marker='x',
                       label=f'{label} cells', linewidth=1.5)

        # Plot trajectories with color gradients (colormaps rebuilt locally so
        # this method is callable on its own).
        custom_colors_1 = ["#05009E", "#A19EFF", "#B83CFF"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        cmaps = [
            LinearSegmentedColormap.from_list("clon_cmap1", custom_colors_1),
            LinearSegmentedColormap.from_list("clon_cmap2", custom_colors_2),
        ]
        for i, traj in enumerate(all_trajs):
            traj_2d = traj[:, :, :2]
            c_end = branch_colors[i]
            cmap = cmaps[i]
            n_time = traj_2d.shape[1]
            color_vals = cmap(np.linspace(0, 1, n_time))
            segments = []
            seg_colors = []
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o',
                       label=f'{branch_names[i]} Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o',
                       label=f'{branch_names[i]} End',
                       zorder=5, edgecolors='white', linewidth=1)

        ax.set_xlabel("PC1", fontsize=14)
        ax.set_ylabel("PC2", fontsize=14)
        ax.set_title("All Branch Trajectories with Timepoint Background",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=12, frameon=False)

        plt.tight_layout()
        plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        plt.close()
1085
+
1086
+
1087
class FlowNetTestTrametinib(BranchFlowNetTrainBase):
    """Test class for Trametinib perturbation experiment (1 or 3 branches).

    Mirrors ``FlowNetTestClonidine`` but with three branch color schemes.
    Integrates each branch flow net from validation x0, computes subsampled
    W1/W2/MMD metrics vs t=1 ground truth, saves JSON/CSV results, and
    renders per-branch and combined trajectory figures.
    """

    def test_step(self, batch, batch_idx):
        """Run all branches, compute subsampled metrics, save results and figures.

        Args:
            batch: test batch from the datamodule's CombinedLoader; may be a
                dict containing "test_samples" or a list/tuple wrapping one.
            batch_idx: Lightning batch index (required by the API; only the
                batch itself is used here, and only to read the device).
        """
        # Handle both dict and tuple batch formats from CombinedLoader
        if isinstance(batch, dict) and "test_samples" in batch:
            # New format: {"test_samples": {...}, "metric_samples": {...}}
            main_batch = batch["test_samples"]
        elif isinstance(batch, (list, tuple)) and len(batch) >= 1:
            # Old format with nested structure
            test_samples = batch[0]
            if isinstance(test_samples, dict) and "test_samples" in test_samples:
                main_batch = test_samples["test_samples"][0]
            else:
                main_batch = test_samples
        else:
            # Fallback
            main_batch = batch

        # Get timepoint data (full per-timepoint datasets from the datamodule)
        timepoint_data = self.trainer.datamodule.get_timepoint_data()
        # The batch is only consulted to find which device x0 should live on.
        device = main_batch["x0"][0].device

        # Use val x0 as initial conditions
        x0 = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
        t_span = torch.linspace(0, 1, 100).to(device)

        # Define color schemes for trametinib (3 branches)
        custom_colors_1 = ["#05009E", "#A19EFF", "#9793F8"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        custom_colors_3 = ["#05009E", "#A19EFF", "#B83CFF"]
        custom_cmap_1 = LinearSegmentedColormap.from_list("cmap1", custom_colors_1)
        custom_cmap_2 = LinearSegmentedColormap.from_list("cmap2", custom_colors_2)
        custom_cmap_3 = LinearSegmentedColormap.from_list("cmap3", custom_colors_3)

        all_trajs = []      # per-branch trajectories, each [B, T, D] on CPU
        all_endpoints = []  # per-branch final states, each [B, D]

        # Integrate each branch's vector field with a Euler neural ODE.
        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)
            all_endpoints.append(traj[:, -1, :])

        # Run 5 trials with random subsampling for robust metrics
        n_trials = 5
        n_branches = len(self.flow_nets)

        # Gather per-branch ground truth: branch-specific 't1_i' keys when
        # available, otherwise the shared 't1' key.
        gt_data_per_branch = []
        for i in range(n_branches):
            if n_branches == 1:
                key = 't1'
            else:
                key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            gt_data_per_branch.append(torch.tensor(timepoint_data[key], dtype=torch.float32))
        gt_all = torch.cat(gt_data_per_branch, dim=0)

        # Per-branch metrics (5 trials). NOTE: distances use only the first
        # two coordinates (the plotted 2D space), not the full state.
        metrics_dict = {}
        for i in range(n_branches):
            w1_br, w2_br, mmd_br = [], [], []
            pred = all_endpoints[i]
            gt = gt_data_per_branch[i]
            for trial in range(n_trials):
                # Subsample both sets to a common size before comparing.
                n_min = min(pred.shape[0], gt.shape[0])
                perm_pred = torch.randperm(pred.shape[0])[:n_min]
                perm_gt = torch.randperm(gt.shape[0])[:n_min]
                m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
                w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
            # Sample std (ddof=1) across the 5 subsampling trials.
            metrics_dict[f"branch_{i+1}"] = {
                "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
                "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
                "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
            }
            self.log(f"test/W1_branch{i+1}", np.mean(w1_br), on_epoch=True)
            print(f"Branch {i+1} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
                  f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
                  f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")

        # Combined metrics (5 trials) over all branch endpoints pooled.
        pred_all = torch.cat(all_endpoints, dim=0)
        w1_trials, w2_trials, mmd_trials = [], [], []
        for trial in range(n_trials):
            n_min = min(pred_all.shape[0], gt_all.shape[0])
            perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
            perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
            m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
            w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])

        w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
        w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
        mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
        self.log("test/W1_t1_combined", w1_mean, on_epoch=True)
        self.log("test/W2_t1_combined", w2_mean, on_epoch=True)
        self.log("test/MMD_t1_combined", mmd_mean, on_epoch=True)
        metrics_dict['t1_combined'] = {
            "W1_mean": float(w1_mean), "W1_std": float(w1_std),
            "W2_mean": float(w2_mean), "W2_std": float(w2_std),
            "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
            "n_trials": n_trials,
        }
        print(f"\n=== Combined @ t1 ===")
        print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
        print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
        print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")

        # Create results directory structure
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        figures_dir = f'{results_dir}/figures'
        os.makedirs(figures_dir, exist_ok=True)

        # Save metrics to JSON
        metrics_path = f'{results_dir}/metrics.json'
        with open(metrics_path, 'w') as f:
            json.dump(metrics_dict, f, indent=2)
        print(f"Metrics saved to {metrics_path}")

        # Save detailed metrics to CSV
        detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
        with open(detailed_csv_path, 'w', newline='') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Metric_Group', 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std'])
            for key in sorted(metrics_dict.keys()):
                m = metrics_dict[key]
                # Fall back to legacy single-value keys ("W1", ...) when the
                # *_mean keys are absent.
                writer.writerow([key,
                                 f'{m.get("W1_mean", m.get("W1", 0)):.6f}', f'{m.get("W1_std", 0):.6f}',
                                 f'{m.get("W2_mean", m.get("W2", 0)):.6f}', f'{m.get("W2_std", 0):.6f}',
                                 f'{m.get("MMD_mean", m.get("MMD", 0)):.6f}', f'{m.get("MMD_std", 0):.6f}'])
        print(f"Detailed metrics CSV saved to {detailed_csv_path}")

        # ===== Plot branches =====
        self._plot_trametinib_branches(all_trajs, timepoint_data, figures_dir,
                                       custom_cmap_1, custom_cmap_2, custom_cmap_3)
        self._plot_trametinib_combined(all_trajs, timepoint_data, figures_dir)

        print(f"Trametinib figures saved to {figures_dir}")

    def _plot_trametinib_branches(self, all_trajs, timepoint_data, save_dir,
                                  cmap1, cmap2, cmap3):
        """Plot each branch separately: t0/t1 background plus trajectories.

        Args:
            all_trajs: list of [B, T, D] trajectories, one per branch.
            timepoint_data: dict of coordinate arrays keyed 't0' and 't1' (or
                branch-specific 't1_1'/'t1_2'/'t1_3').
            save_dir: output directory for the per-branch PNGs.
            cmap1, cmap2, cmap3: per-branch trajectory colormaps.
        """
        branch_names = ['Branch 1', 'Branch 2', 'Branch 3']
        branch_colors = ['#9793F8', '#50B2D7', '#B83CFF']
        cmaps = [cmap1, cmap2, cmap3]

        # Compute global axis limits – handle single vs multi branch keys
        all_coords = []
        if 't1_1' in timepoint_data:
            tp_keys = ['t0'] + [f't1_{i+1}' for i in range(len(all_trajs))]
        else:
            tp_keys = ['t0', 't1']
        for key in tp_keys:
            all_coords.append(timepoint_data[key][:, :2])
        for traj in all_trajs:
            all_coords.append(traj.reshape(-1, traj.shape[-1])[:, :2])

        all_coords = np.concatenate(all_coords, axis=0)
        x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
        y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()

        # Add a 5% margin around the data extent
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min -= x_margin
        x_max += x_margin
        y_min -= y_margin
        y_max += y_margin

        for i, traj in enumerate(all_trajs):
            fig, ax = plt.subplots(figsize=(10, 8))
            c_end = branch_colors[i]

            # Plot timepoint background
            t1_key = f't1_{i+1}' if f't1_{i+1}' in timepoint_data else 't1'
            coords_list = [timepoint_data['t0'], timepoint_data[t1_key]]
            tp_colors = ['#05009E', c_end]
            t1_label = f"t=1 (branch {i+1})" if len(all_trajs) > 1 else "t=1"
            tp_labels = ["t=0", t1_label]

            for coords, color, label in zip(coords_list, tp_colors, tp_labels):
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=color, s=80, alpha=0.4, marker='x',
                           label=f'{label} cells', linewidth=1.5)

            # Plot continuous trajectories with LineCollection for speed
            traj_2d = traj[:, :, :2]
            n_time = traj_2d.shape[1]
            color_vals = cmaps[i](np.linspace(0, 1, n_time))
            segments = []
            seg_colors = []
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            # Start and end points
            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o', label='Trajectory Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o', label='Trajectory End',
                       zorder=5, edgecolors='white', linewidth=1)

            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.set_xlabel("PC1", fontsize=12)
            ax.set_ylabel("PC2", fontsize=12)
            ax.set_title(f"{branch_names[i]}: Trajectories with Timepoint Background", fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='upper right', fontsize=16, frameon=False)

            plt.tight_layout()
            plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
            plt.close()

    def _plot_trametinib_combined(self, all_trajs, timepoint_data, save_dir):
        """Plot all 3 branches together on one axis and save a combined PNG."""
        branch_names = ['Branch 1', 'Branch 2', 'Branch 3']
        branch_colors = ['#9793F8', '#50B2D7', '#B83CFF']

        fig, ax = plt.subplots(figsize=(12, 10))

        # Build timepoint keys/colors/labels depending on single vs multi branch
        if 't1_1' in timepoint_data:
            tp_keys = ['t0'] + [f't1_{j+1}' for j in range(len(all_trajs))]
            tp_labels_list = ['t=0'] + [f't=1 (branch {j+1})' for j in range(len(all_trajs))]
        else:
            tp_keys = ['t0', 't1']
            tp_labels_list = ['t=0', 't=1']
        tp_colors = ['#05009E', '#9793F8', '#50B2D7', '#B83CFF'][:len(tp_keys)]

        # Plot timepoint background
        for t_key, color, label in zip(tp_keys, tp_colors, tp_labels_list):
            coords = timepoint_data[t_key]
            ax.scatter(coords[:, 0], coords[:, 1],
                       c=color, s=80, alpha=0.4, marker='x',
                       label=f'{label} cells', linewidth=1.5)

        # Plot trajectories with color gradients. NOTE(review): branch 3's
        # gradient here ends at '#D577FF' while test_step and the endpoint
        # markers use '#B83CFF' — presumably a deliberate lighter shade for
        # the combined view, but worth confirming.
        custom_colors_1 = ["#05009E", "#A19EFF", "#9793F8"]
        custom_colors_2 = ["#05009E", "#A19EFF", "#50B2D7"]
        custom_colors_3 = ["#05009E", "#A19EFF", "#D577FF"]
        cmaps = [
            LinearSegmentedColormap.from_list("tram_cmap1", custom_colors_1),
            LinearSegmentedColormap.from_list("tram_cmap2", custom_colors_2),
            LinearSegmentedColormap.from_list("tram_cmap3", custom_colors_3),
        ]
        for i, traj in enumerate(all_trajs):
            traj_2d = traj[:, :, :2]
            c_end = branch_colors[i]
            cmap = cmaps[i]
            n_time = traj_2d.shape[1]
            color_vals = cmap(np.linspace(0, 1, n_time))
            segments = []
            seg_colors = []
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
            ax.add_collection(lc)

            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=30, marker='o',
                       label=f'{branch_names[i]} Start',
                       zorder=5, edgecolors='white', linewidth=1)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=30, marker='o',
                       label=f'{branch_names[i]} End',
                       zorder=5, edgecolors='white', linewidth=1)

        ax.set_xlabel("PC1", fontsize=14)
        ax.set_ylabel("PC2", fontsize=14)
        ax.set_title("All Branch Trajectories with Timepoint Background",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=12, frameon=False)

        plt.tight_layout()
        plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        plt.close()
1384
+
1385
+ class FlowNetTestVeres(GrowthNetTrain):
1386
+ """Test class for Veres pancreatic endocrinogenesis experiment (3 or 5 branches)."""
1387
+
1388
+ def test_step(self, batch, batch_idx):
1389
+ # Handle both tuple and dict batch formats from CombinedLoader
1390
+ if isinstance(batch, dict):
1391
+ main_batch = batch["test_samples"][0]
1392
+ metric_batch = batch["metric_samples"][0]
1393
+ else:
1394
+ # batch is a list/tuple
1395
+ if isinstance(batch[0], dict):
1396
+ # batch[0] contains the dict with test_samples and metric_samples
1397
+ main_batch = batch[0]["test_samples"][0]
1398
+ metric_batch = batch[0]["metric_samples"][0]
1399
+ else:
1400
+ # batch is a tuple: (test_samples, metric_samples)
1401
+ main_batch = batch[0][0]
1402
+ metric_batch = batch[1][0]
1403
+
1404
+ # Get timepoint data (full datasets, not just val split)
1405
+ timepoint_data = self.trainer.datamodule.get_timepoint_data()
1406
+ device = main_batch["x0"][0].device
1407
+
1408
+ # Use val x0 as initial conditions
1409
+ x0_all = self.trainer.datamodule.val_dataloaders["x0"].dataset.tensors[0].to(device)
1410
+ w0_all = torch.ones(x0_all.shape[0], 1, dtype=torch.float32).to(device)
1411
+ full_batch = {"x0": (x0_all, w0_all)}
1412
+
1413
+ time_points, all_endpoints, all_trajs, mass_over_time, energy_over_time, weights_over_time = self.get_mass_and_position(full_batch, metric_batch)
1414
+
1415
+ n_branches = len(self.flow_nets)
1416
+
1417
+ # trajectory time grid
1418
+ t_span = torch.linspace(0, 1, 101).to(device)
1419
+
1420
+ # `all_trajs` returned from `get_mass_and_position` is expected to be a list where each
1421
+ # element is a sequence of per-timepoint tensors for that branch (shape [B, D] each).
1422
+ # Convert each branch to [T, B, D] then to [B, T, D] for downstream processing.
1423
+ trajs_TBD = [torch.stack(branch_list, dim=0) for branch_list in all_trajs] # each is [T, B, D]
1424
+ trajs_BTD = [t.permute(1, 0, 2) for t in trajs_TBD] # each -> [B, T, D]
1425
+
1426
+ all_trajs = []
1427
+ all_endpoints = []
1428
+ # will store per-branch intermediate frames: each entry -> tensor [B, n_intermediate, D]
1429
+ all_intermediates = []
1430
+
1431
+ for traj in trajs_BTD:
1432
+ # traj is [B, T, D]
1433
+ # optionally inverse-transform if whitened
1434
+ if self.whiten:
1435
+ traj_np = traj.detach().cpu().numpy()
1436
+ n_samples, n_time, n_dims = traj_np.shape
1437
+ traj_flat = traj_np.reshape(-1, n_dims)
1438
+ traj_inv_flat = self.trainer.datamodule.scaler.inverse_transform(traj_flat)
1439
+ traj_inv = traj_inv_flat.reshape(n_samples, n_time, n_dims)
1440
+ traj = torch.tensor(traj_inv, dtype=torch.float32)
1441
+
1442
+ all_trajs.append(traj)
1443
+
1444
+ # Collect six evenly spaced intermediate frames between t=0 and t=1 (exclude endpoints)
1445
+ n_T = traj.shape[1]
1446
+ # choose 8 points including endpoints -> take inner 6 as intermediates
1447
+ inter_times = np.linspace(0.0, 1.0, 8)[1:-1] # 6 values
1448
+ inter_indices = [int(round(t * (n_T - 1))) for t in inter_times]
1449
+ # stack per-branch intermediate frames -> [B, 6, D]
1450
+ intermediates = torch.stack([traj[:, idx, :] for idx in inter_indices], dim=1)
1451
+ all_intermediates.append(intermediates)
1452
+
1453
+ # Final endpoints (t=1)
1454
+ all_endpoints.append(traj[:, -1, :])
1455
+
1456
+ # Run 5 trials with random subsampling for robust metrics
1457
+ n_trials = 5
1458
+ metrics_dict = {}
1459
+
1460
+ # --- Intermediate timepoints (t1-t6) combined metrics ---
1461
+ intermediate_keys = sorted([k for k in timepoint_data.keys()
1462
+ if k.startswith('t') and '_' not in k and k != 't0'])
1463
+
1464
+ if intermediate_keys:
1465
+ n_evals = min(6, len(intermediate_keys))
1466
+ for j in range(n_evals):
1467
+ intermediate_key = intermediate_keys[j]
1468
+ true_data_intermediate = torch.tensor(timepoint_data[intermediate_key], dtype=torch.float32)
1469
+
1470
+ # Gather predicted intermediates across all branches
1471
+ raw_intermediates = [branch[:, j, :] for branch in all_intermediates]
1472
+ all_raw_concat = torch.cat(raw_intermediates, dim=0).cpu() # [n_branches*B, D]
1473
+
1474
+ w1_t, w2_t, mmd_t = [], [], []
1475
+ w1_t_full, w2_t_full, mmd_t_full = [], [], []
1476
+ for trial in range(n_trials):
1477
+ n_min = min(all_raw_concat.shape[0], true_data_intermediate.shape[0])
1478
+ perm_pred = torch.randperm(all_raw_concat.shape[0])[:n_min]
1479
+ perm_gt = torch.randperm(true_data_intermediate.shape[0])[:n_min]
1480
+ # 2D metrics (PC1-PC2)
1481
+ m = compute_distribution_distances(
1482
+ all_raw_concat[perm_pred, :2], true_data_intermediate[perm_gt, :2])
1483
+ w1_t.append(m["W1"]); w2_t.append(m["W2"]); mmd_t.append(m["MMD"])
1484
+ # Full-dimensional metrics (all PCs)
1485
+ m_full = compute_distribution_distances(
1486
+ all_raw_concat[perm_pred], true_data_intermediate[perm_gt])
1487
+ w1_t_full.append(m_full["W1"]); w2_t_full.append(m_full["W2"]); mmd_t_full.append(m_full["MMD"])
1488
+
1489
+ metrics_dict[f'{intermediate_key}_combined'] = {
1490
+ "W1_mean": float(np.mean(w1_t)), "W1_std": float(np.std(w1_t, ddof=1)),
1491
+ "W2_mean": float(np.mean(w2_t)), "W2_std": float(np.std(w2_t, ddof=1)),
1492
+ "MMD_mean": float(np.mean(mmd_t)), "MMD_std": float(np.std(mmd_t, ddof=1)),
1493
+ "W1_full_mean": float(np.mean(w1_t_full)), "W1_full_std": float(np.std(w1_t_full, ddof=1)),
1494
+ "W2_full_mean": float(np.mean(w2_t_full)), "W2_full_std": float(np.std(w2_t_full, ddof=1)),
1495
+ "MMD_full_mean": float(np.mean(mmd_t_full)), "MMD_full_std": float(np.std(mmd_t_full, ddof=1)),
1496
+ }
1497
+ self.log(f"test/W1_{intermediate_key}_combined", np.mean(w1_t), on_epoch=True)
1498
+ self.log(f"test/W1_full_{intermediate_key}_combined", np.mean(w1_t_full), on_epoch=True)
1499
+ print(f"{intermediate_key} combined — W1: {np.mean(w1_t):.6f}±{np.std(w1_t, ddof=1):.6f}, "
1500
+ f"W2: {np.mean(w2_t):.6f}±{np.std(w2_t, ddof=1):.6f}, "
1501
+ f"MMD: {np.mean(mmd_t):.6f}±{np.std(mmd_t, ddof=1):.6f}")
1502
+ print(f"{intermediate_key} combined (full) — W1: {np.mean(w1_t_full):.6f}±{np.std(w1_t_full, ddof=1):.6f}, "
1503
+ f"W2: {np.mean(w2_t_full):.6f}±{np.std(w2_t_full, ddof=1):.6f}, "
1504
+ f"MMD: {np.mean(mmd_t_full):.6f}±{np.std(mmd_t_full, ddof=1):.6f}")
1505
+
1506
+ # --- Final timepoint per-branch metrics ---
1507
+ gt_keys = sorted([k for k in timepoint_data.keys() if k.startswith('t7_')])
1508
+ for i, endpoints in enumerate(all_endpoints):
1509
+ true_data_key = f"t7_{i}"
1510
+ if true_data_key not in timepoint_data:
1511
+ print(f"Warning: {true_data_key} not found in timepoint_data")
1512
+ continue
1513
+ gt = torch.tensor(timepoint_data[true_data_key], dtype=torch.float32)
1514
+ pred = endpoints.cpu()
1515
+
1516
+ w1_br, w2_br, mmd_br = [], [], []
1517
+ w1_br_full, w2_br_full, mmd_br_full = [], [], []
1518
+ for trial in range(n_trials):
1519
+ n_min = min(pred.shape[0], gt.shape[0])
1520
+ perm_pred = torch.randperm(pred.shape[0])[:n_min]
1521
+ perm_gt = torch.randperm(gt.shape[0])[:n_min]
1522
+ # 2D metrics (PC1-PC2)
1523
+ m = compute_distribution_distances(pred[perm_pred, :2], gt[perm_gt, :2])
1524
+ w1_br.append(m["W1"]); w2_br.append(m["W2"]); mmd_br.append(m["MMD"])
1525
+ # Full-dimensional metrics (all PCs)
1526
+ m_full = compute_distribution_distances(pred[perm_pred], gt[perm_gt])
1527
+ w1_br_full.append(m_full["W1"]); w2_br_full.append(m_full["W2"]); mmd_br_full.append(m_full["MMD"])
1528
+
1529
+ metrics_dict[f"branch_{i}"] = {
1530
+ "W1_mean": float(np.mean(w1_br)), "W1_std": float(np.std(w1_br, ddof=1)),
1531
+ "W2_mean": float(np.mean(w2_br)), "W2_std": float(np.std(w2_br, ddof=1)),
1532
+ "MMD_mean": float(np.mean(mmd_br)), "MMD_std": float(np.std(mmd_br, ddof=1)),
1533
+ "W1_full_mean": float(np.mean(w1_br_full)), "W1_full_std": float(np.std(w1_br_full, ddof=1)),
1534
+ "W2_full_mean": float(np.mean(w2_br_full)), "W2_full_std": float(np.std(w2_br_full, ddof=1)),
1535
+ "MMD_full_mean": float(np.mean(mmd_br_full)), "MMD_full_std": float(np.std(mmd_br_full, ddof=1)),
1536
+ }
1537
+ self.log(f"test/W1_branch{i}", np.mean(w1_br), on_epoch=True)
1538
+ self.log(f"test/W1_full_branch{i}", np.mean(w1_br_full), on_epoch=True)
1539
+ print(f"Branch {i} — W1: {np.mean(w1_br):.6f}±{np.std(w1_br, ddof=1):.6f}, "
1540
+ f"W2: {np.mean(w2_br):.6f}±{np.std(w2_br, ddof=1):.6f}, "
1541
+ f"MMD: {np.mean(mmd_br):.6f}±{np.std(mmd_br, ddof=1):.6f}")
1542
+ print(f"Branch {i} (full) — W1: {np.mean(w1_br_full):.6f}±{np.std(w1_br_full, ddof=1):.6f}, "
1543
+ f"W2: {np.mean(w2_br_full):.6f}±{np.std(w2_br_full, ddof=1):.6f}, "
1544
+ f"MMD: {np.mean(mmd_br_full):.6f}±{np.std(mmd_br_full, ddof=1):.6f}")
1545
+
1546
+ # --- Final timepoint combined metrics ---
1547
+ gt_list = [torch.tensor(timepoint_data[k], dtype=torch.float32) for k in gt_keys]
1548
+ if len(gt_list) > 0 and len(all_endpoints) > 0:
1549
+ gt_all = torch.cat(gt_list, dim=0)
1550
+ pred_all = torch.cat([e.cpu() for e in all_endpoints], dim=0)
1551
+
1552
+ w1_trials, w2_trials, mmd_trials = [], [], []
1553
+ w1_trials_full, w2_trials_full, mmd_trials_full = [], [], []
1554
+ for trial in range(n_trials):
1555
+ n_min = min(pred_all.shape[0], gt_all.shape[0])
1556
+ perm_pred = torch.randperm(pred_all.shape[0])[:n_min]
1557
+ perm_gt = torch.randperm(gt_all.shape[0])[:n_min]
1558
+ # 2D metrics (PC1-PC2)
1559
+ m = compute_distribution_distances(pred_all[perm_pred, :2], gt_all[perm_gt, :2])
1560
+ w1_trials.append(m["W1"]); w2_trials.append(m["W2"]); mmd_trials.append(m["MMD"])
1561
+ # Full-dimensional metrics (all PCs)
1562
+ m_full = compute_distribution_distances(pred_all[perm_pred], gt_all[perm_gt])
1563
+ w1_trials_full.append(m_full["W1"]); w2_trials_full.append(m_full["W2"]); mmd_trials_full.append(m_full["MMD"])
1564
+
1565
+ w1_mean, w1_std = np.mean(w1_trials), np.std(w1_trials, ddof=1)
1566
+ w2_mean, w2_std = np.mean(w2_trials), np.std(w2_trials, ddof=1)
1567
+ mmd_mean, mmd_std = np.mean(mmd_trials), np.std(mmd_trials, ddof=1)
1568
+ w1_mean_f, w1_std_f = np.mean(w1_trials_full), np.std(w1_trials_full, ddof=1)
1569
+ w2_mean_f, w2_std_f = np.mean(w2_trials_full), np.std(w2_trials_full, ddof=1)
1570
+ mmd_mean_f, mmd_std_f = np.mean(mmd_trials_full), np.std(mmd_trials_full, ddof=1)
1571
+ self.log("test/W1_t7_combined", w1_mean, on_epoch=True)
1572
+ self.log("test/W2_t7_combined", w2_mean, on_epoch=True)
1573
+ self.log("test/MMD_t7_combined", mmd_mean, on_epoch=True)
1574
+ self.log("test/W1_full_t7_combined", w1_mean_f, on_epoch=True)
1575
+ self.log("test/W2_full_t7_combined", w2_mean_f, on_epoch=True)
1576
+ self.log("test/MMD_full_t7_combined", mmd_mean_f, on_epoch=True)
1577
+ metrics_dict['t7_combined'] = {
1578
+ "W1_mean": float(w1_mean), "W1_std": float(w1_std),
1579
+ "W2_mean": float(w2_mean), "W2_std": float(w2_std),
1580
+ "MMD_mean": float(mmd_mean), "MMD_std": float(mmd_std),
1581
+ "W1_full_mean": float(w1_mean_f), "W1_full_std": float(w1_std_f),
1582
+ "W2_full_mean": float(w2_mean_f), "W2_full_std": float(w2_std_f),
1583
+ "MMD_full_mean": float(mmd_mean_f), "MMD_full_std": float(mmd_std_f),
1584
+ "n_trials": n_trials,
1585
+ }
1586
+ print(f"\n=== Combined @ t7 ===")
1587
+ print(f"W1: {w1_mean:.6f} ± {w1_std:.6f}")
1588
+ print(f"W2: {w2_mean:.6f} ± {w2_std:.6f}")
1589
+ print(f"MMD: {mmd_mean:.6f} ± {mmd_std:.6f}")
1590
+ print(f"W1 (full): {w1_mean_f:.6f} ± {w1_std_f:.6f}")
1591
+ print(f"W2 (full): {w2_mean_f:.6f} ± {w2_std_f:.6f}")
1592
+ print(f"MMD (full): {mmd_mean_f:.6f} ± {mmd_std_f:.6f}")
1593
+
1594
+ # Create results directory structure
1595
+ run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
1596
+ results_dir = os.path.join(self.args.working_dir, 'results', run_name)
1597
+ figures_dir = f'{results_dir}/figures'
1598
+ os.makedirs(figures_dir, exist_ok=True)
1599
+
1600
+ # Save metrics to JSON
1601
+ metrics_path = f'{results_dir}/metrics.json'
1602
+ with open(metrics_path, 'w') as f:
1603
+ json.dump(metrics_dict, f, indent=2)
1604
+ print(f"Metrics saved to {metrics_path}")
1605
+
1606
+ # Save detailed metrics to CSV
1607
+ detailed_csv_path = f'{results_dir}/metrics_detailed.csv'
1608
+ with open(detailed_csv_path, 'w', newline='') as csvfile:
1609
+ writer = csv.writer(csvfile)
1610
+ writer.writerow(['Metric_Group',
1611
+ 'W1_Mean', 'W1_Std', 'W2_Mean', 'W2_Std', 'MMD_Mean', 'MMD_Std',
1612
+ 'W1_Full_Mean', 'W1_Full_Std', 'W2_Full_Mean', 'W2_Full_Std', 'MMD_Full_Mean', 'MMD_Full_Std'])
1613
+ for key in sorted(metrics_dict.keys()):
1614
+ m = metrics_dict[key]
1615
+ writer.writerow([key,
1616
+ f'{m.get("W1_mean", 0):.6f}', f'{m.get("W1_std", 0):.6f}',
1617
+ f'{m.get("W2_mean", 0):.6f}', f'{m.get("W2_std", 0):.6f}',
1618
+ f'{m.get("MMD_mean", 0):.6f}', f'{m.get("MMD_std", 0):.6f}',
1619
+ f'{m.get("W1_full_mean", 0):.6f}', f'{m.get("W1_full_std", 0):.6f}',
1620
+ f'{m.get("W2_full_mean", 0):.6f}', f'{m.get("W2_full_std", 0):.6f}',
1621
+ f'{m.get("MMD_full_mean", 0):.6f}', f'{m.get("MMD_full_std", 0):.6f}'])
1622
+ print(f"Detailed metrics CSV saved to {detailed_csv_path}")
1623
+
1624
+ # ===== Plot branches =====
1625
+ self._plot_veres_branches(all_trajs, timepoint_data, figures_dir, n_branches)
1626
+ self._plot_veres_combined(all_trajs, timepoint_data, figures_dir, n_branches)
1627
+
1628
+ print(f"Veres figures saved to {figures_dir}")
1629
+
1630
+ def _plot_veres_branches(self, all_trajs, timepoint_data, save_dir, n_branches):
1631
+ """Plot each branch separately in PCA space (PC1 vs PC2)."""
1632
+ branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9',
1633
+ '#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8',
1634
+ '#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8',
1635
+ '#64B5F6', '#C5E1A5']
1636
+
1637
+ # Project to first 2 PCs (data is already in PCA space)
1638
+ t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2]
1639
+ t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)]
1640
+
1641
+ # Slice trajectories to first 2 PCs
1642
+ trajs_2d = []
1643
+ for traj in all_trajs:
1644
+ trajs_2d.append(traj.cpu().numpy()[:, :, :2]) # [n_samples, n_time, 2]
1645
+
1646
+ # Compute global axis limits
1647
+ all_coords = [t0_2d] + t7_2d
1648
+ for traj_2d in trajs_2d:
1649
+ all_coords.append(traj_2d.reshape(-1, 2))
1650
+
1651
+ all_coords = np.concatenate(all_coords, axis=0)
1652
+ x_min, x_max = all_coords[:, 0].min(), all_coords[:, 0].max()
1653
+ y_min, y_max = all_coords[:, 1].min(), all_coords[:, 1].max()
1654
+
1655
+ x_margin = 0.05 * (x_max - x_min)
1656
+ y_margin = 0.05 * (y_max - y_min)
1657
+ x_min -= x_margin
1658
+ x_max += x_margin
1659
+ y_min -= y_margin
1660
+ y_max += y_margin
1661
+
1662
+ for i, traj_2d in enumerate(trajs_2d):
1663
+ fig, ax = plt.subplots(figsize=(10, 8))
1664
+ c_end = branch_colors[i % len(branch_colors)]
1665
+
1666
+ # Plot timepoint background
1667
+ ax.scatter(t0_2d[:, 0], t0_2d[:, 1],
1668
+ c='#05009E', s=80, alpha=0.4, marker='x',
1669
+ label='t=0 cells', linewidth=1.5)
1670
+ ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1],
1671
+ c=c_end, s=80, alpha=0.4, marker='x',
1672
+ label=f't=7 (branch {i+1}) cells', linewidth=1.5)
1673
+
1674
+ # Plot continuous trajectories with LineCollection for speed
1675
+ cmap_colors = ["#05009E", "#A19EFF", c_end]
1676
+ cmap = LinearSegmentedColormap.from_list(f"veres_cmap_{i}", cmap_colors)
1677
+ n_time = traj_2d.shape[1]
1678
+ segments = []
1679
+ seg_colors = []
1680
+ color_vals = cmap(np.linspace(0, 1, n_time))
1681
+ for j in range(traj_2d.shape[0]):
1682
+ pts = traj_2d[j] # [T, 2]
1683
+ segs = np.stack([pts[:-1], pts[1:]], axis=1) # [T-1, 2, 2]
1684
+ segments.append(segs)
1685
+ seg_colors.append(color_vals[:-1])
1686
+ segments = np.concatenate(segments, axis=0)
1687
+ seg_colors = np.concatenate(seg_colors, axis=0)
1688
+ lc = LineCollection(segments, colors=seg_colors, linewidths=2, alpha=0.8)
1689
+ ax.add_collection(lc)
1690
+
1691
+ # Start and end points
1692
+ ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
1693
+ c='#05009E', s=30, marker='o', label='Trajectory start (t=0)',
1694
+ zorder=5, edgecolors='white', linewidth=1)
1695
+ ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
1696
+ c=c_end, s=30, marker='o', label='Trajectory end (t=1)',
1697
+ zorder=5, edgecolors='white', linewidth=1)
1698
+
1699
+ ax.set_xlim(x_min, x_max)
1700
+ ax.set_ylim(y_min, y_max)
1701
+ ax.set_xlabel("PC 1", fontsize=12)
1702
+ ax.set_ylabel("PC 2", fontsize=12)
1703
+ ax.set_title(f"Branch {i+1}: Trajectories (PCA)", fontsize=14)
1704
+ ax.grid(True, alpha=0.3)
1705
+ ax.legend(loc='upper right', fontsize=9, frameon=False)
1706
+
1707
+ plt.tight_layout()
1708
+ plt.savefig(f'{save_dir}/{self.args.data_name}_branch{i+1}.png', dpi=300)
1709
+ plt.close()
1710
+
1711
    def _plot_veres_combined(self, all_trajs, timepoint_data, save_dir, n_branches):
        """Plot all branches together in PCA space (PC1 vs PC2).

        Args:
            all_trajs: per-branch trajectory tensors, each [n_samples, n_time, D].
            timepoint_data: dict with keys 't0' and 't7_{i}' holding real-cell
                coordinates (already in PCA space).
            save_dir: directory the PNG is written into.
            n_branches: number of branches to draw.
        """
        branch_colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DFE6E9',
                         '#74B9FF', '#A29BFE', '#FFB74D', '#AED581', '#F06292', '#BA68C8',
                         '#4DB6AC', '#81C784', '#FFD54F', '#90A4AE', '#F48FB1', '#CE93D8',
                         '#64B5F6', '#C5E1A5']

        # Project to first 2 PCs (data is already in PCA space)
        t0_2d = timepoint_data['t0'].cpu().numpy()[:, :2]
        t7_2d = [timepoint_data[f't7_{i}'].cpu().numpy()[:, :2] for i in range(n_branches)]

        # Slice trajectories to first 2 PCs
        trajs_2d = []
        for traj in all_trajs:
            trajs_2d.append(traj.cpu().numpy()[:, :, :2])  # [n_samples, n_time, 2]

        # Compute axis limits from REAL CELLS ONLY
        # (unlike the per-branch plots, simulated trajectories may leave the frame)
        all_coords_real = [t0_2d] + t7_2d
        all_coords_real = np.concatenate(all_coords_real, axis=0)
        x_min, x_max = all_coords_real[:, 0].min(), all_coords_real[:, 0].max()
        y_min, y_max = all_coords_real[:, 1].min(), all_coords_real[:, 1].max()
        x_margin = 0.05 * (x_max - x_min)
        y_margin = 0.05 * (y_max - y_min)
        x_min -= x_margin
        x_max += x_margin
        y_min -= y_margin
        y_max += y_margin

        fig, ax = plt.subplots(figsize=(14, 12))
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)

        # Plot t=0 cells
        ax.scatter(t0_2d[:, 0], t0_2d[:, 1],
                   c='#05009E', s=60, alpha=0.3, marker='x',
                   label='t=0 cells', linewidth=1.5)

        # Plot each branch's cells and trajectories
        for i, traj_2d in enumerate(trajs_2d):
            # cycle through the palette if there are more branches than colors
            c_end = branch_colors[i % len(branch_colors)]

            # Plot t=7 cells for this branch
            ax.scatter(t7_2d[i][:, 0], t7_2d[i][:, 1],
                       c=c_end, s=60, alpha=0.3, marker='x',
                       label=f't=7 (branch {i+1})', linewidth=1.5)

            # Plot continuous trajectories with LineCollection for speed
            cmap_colors = ["#05009E", "#A19EFF", c_end]
            cmap = LinearSegmentedColormap.from_list(f"veres_combined_cmap_{i}", cmap_colors)
            n_time = traj_2d.shape[1]
            segments = []
            seg_colors = []
            # one color per time step; segments inherit the color of their start point
            color_vals = cmap(np.linspace(0, 1, n_time))
            for j in range(traj_2d.shape[0]):
                pts = traj_2d[j]  # [T, 2]
                segs = np.stack([pts[:-1], pts[1:]], axis=1)  # [T-1, 2, 2]
                segments.append(segs)
                seg_colors.append(color_vals[:-1])
            segments = np.concatenate(segments, axis=0)
            seg_colors = np.concatenate(seg_colors, axis=0)
            lc = LineCollection(segments, colors=seg_colors, linewidths=1.5, alpha=0.6)
            ax.add_collection(lc)

            # Start and end points
            ax.scatter(traj_2d[:, 0, 0], traj_2d[:, 0, 1],
                       c='#05009E', s=20, marker='o',
                       zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7)
            ax.scatter(traj_2d[:, -1, 0], traj_2d[:, -1, 1],
                       c=c_end, s=20, marker='o',
                       zorder=5, edgecolors='white', linewidth=0.5, alpha=0.7)

        ax.set_xlabel("PC 1", fontsize=14)
        ax.set_ylabel("PC 2", fontsize=14)
        ax.set_title(f"All {n_branches} Branch Trajectories (Veres) - PCA Projection",
                     fontsize=16, weight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper right', fontsize=10, frameon=False, ncol=2)

        plt.tight_layout()
        plt.savefig(f'{save_dir}/{self.args.data_name}_combined.png', dpi=300)
        plt.close()
src/branch_flow_net_train.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import wandb
5
+ import matplotlib.pyplot as plt
6
+ import pytorch_lightning as pl
7
+ from torch.optim import AdamW
8
+ from torchmetrics.functional import mean_squared_error
9
+ from torchdyn.core import NeuralODE
10
+ from .networks.utils import flow_model_torch_wrapper
11
+ from .utils import wasserstein, plot_lidar
12
+ from .ema import EMA
13
+
14
class BranchFlowNetTrainBase(pl.LightningModule):
    """Base Lightning module for training per-branch conditional flow-matching nets.

    One velocity-field network is trained per branch. All branches share the
    same source samples ``x0``; branch ``i`` regresses onto its own endpoint
    samples ``x1_{i+1}`` (or plain ``x1`` in the single-branch case).

    Args:
        flow_matcher: object providing ``sample_location_and_conditional_flow``
            (and ``alpha`` / ``geopath_net_output`` used for logging).
        flow_nets: list of flow networks, one per branch.
        skipped_time_points: optional list of held-out timepoint indices; the
            flow bridges directly across the first one.
        ot_sampler: optional OT sampler used to re-pair (x0, x1) minibatches.
        args: namespace with ``flow_optimizer``, ``flow_lr``,
            ``flow_weight_decay``, ``whiten`` and ``working_dir``.
    """

    def __init__(
        self,
        flow_matcher,
        flow_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
    ):
        super().__init__()
        self.args = args

        self.flow_matcher = flow_matcher
        self.flow_nets = flow_nets  # list of flow networks for each branch
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        self.optimizer_name = args.flow_optimizer
        self.lr = args.flow_lr
        self.weight_decay = args.flow_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        # branching: one flow network per branch
        self.branches = len(flow_nets)

    def forward(self, t, xt, branch_idx):
        """Return the velocity predicted by the ``branch_idx``-th flow network."""
        return self.flow_nets[branch_idx](t, xt)

    def _compute_loss(self, main_batch):
        """Sum of per-branch flow-matching MSE losses for one batch."""
        x0s = [main_batch["x0"][0]]
        w0s = [main_batch["x0"][1]]  # NOTE(review): weights are collected but unused here

        x1s_list = []
        w1s_list = []

        # Multi-branch batches carry keys x1_1..x1_B; single branch uses x1.
        if self.branches > 1:
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i+1}"][0]])
                w1s_list.append([main_batch[f"x1_{i+1}"][1]])
        else:
            x1s_list.append([main_batch["x1"][0]])
            w1s_list.append([main_batch["x1"][1]])

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        loss = 0
        for branch_idx in range(self.branches):
            ts, xts, uts = self._process_flow(x0s, x1s_list[branch_idx], branch_idx)

            t = torch.cat(ts)
            xt = torch.cat(xts)
            ut = torch.cat(uts)
            vt = self(t[:, None], xt, branch_idx)

            loss += mean_squared_error(vt, ut)

        return loss

    def _process_flow(self, x0s, x1s, branch_idx):
        """Sample (t, x_t, u_t) regression targets for consecutive timepoint pairs."""
        ts, xts, uts = [], [], []
        t_start = self.timesteps[0]

        for i, (x0, x1) in enumerate(zip(x0s, x1s)):

            x0, x1 = torch.squeeze(x0), torch.squeeze(x1)

            # Re-pair the minibatch with an OT coupling when a sampler is given.
            if self.ot_sampler is not None:
                x0, x1 = self.ot_sampler.sample_plan(
                    x0,
                    x1,
                    replace=True,
                )
            # When a timepoint is held out, bridge directly across it.
            if self.skipped_time_points and i + 1 >= self.skipped_time_points[0]:
                t_start_next = self.timesteps[i + 2]
            else:
                t_start_next = self.timesteps[i + 1]

            # Sample from the branch-conditional flow matcher.
            t, xt, ut = self.flow_matcher.sample_location_and_conditional_flow(
                x0, x1, t_start, t_start_next, branch_idx
            )

            ts.append(t)
            xts.append(xt)
            uts.append(ut)
            t_start = t_start_next
        return ts, xts, uts

    def training_step(self, batch, batch_idx):
        # Handle both dict and tuple batch formats from CombinedLoader
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        if isinstance(batch, dict) and "train_samples" in batch:
            main_batch = batch["train_samples"]
            if isinstance(main_batch, tuple):
                main_batch = main_batch[0]
        else:
            # Fallback
            main_batch = batch.get("train_samples", batch)

        # Uniform time grid with one node per stored timepoint.
        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        loss = self._compute_loss(main_batch)
        if self.flow_matcher.alpha != 0:
            self.log(
                "FlowNet/mean_geopath_cfm",
                (self.flow_matcher.geopath_net_output.abs().mean()),
                on_step=False,
                on_epoch=True,
                prog_bar=True,
            )

        self.log(
            "FlowNet/train_loss_cfm",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        # Handle both dict and tuple batch formats from CombinedLoader
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        if isinstance(batch, dict) and "val_samples" in batch:
            main_batch = batch["val_samples"]
            if isinstance(main_batch, tuple):
                main_batch = main_batch[0]
        else:
            # Fallback
            main_batch = batch.get("val_samples", batch)

        self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
        val_loss = self._compute_loss(main_batch)
        self.log(
            "FlowNet/val_loss_cfm",
            val_loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return val_loss

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)

        # Keep EMA shadow weights in sync after every optimizer update.
        for net in self.flow_nets:
            if isinstance(net, EMA):
                net.update_ema()

    def configure_optimizers(self):
        if self.optimizer_name == "adamw":
            optimizer = AdamW(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.weight_decay,
            )
        elif self.optimizer_name == "adam":
            optimizer = torch.optim.Adam(
                self.parameters(),
                lr=self.lr,
            )
        else:
            # Fail with a clear message instead of NameError on the return below.
            raise ValueError(f"Unknown flow optimizer: {self.optimizer_name!r}")

        return optimizer
190
+
191
+
192
class FlowNetTrainTrajectory(BranchFlowNetTrainBase):
    """Test variant that scores the flow at a held-out intermediate timepoint."""

    def test_step(self, batch, batch_idx):
        """Integrate from the timepoint before the held-out one and log EMD.

        Fixes vs. original: removed the unused ``data_type`` local, and the
        EMD computation is guarded so it only runs when a held-out timepoint
        exists (``X_mid_pred`` is undefined otherwise).
        """
        node = NeuralODE(
            flow_model_torch_wrapper(self.flow_nets),
            solver="euler",
            sensitivity="adjoint",
            atol=1e-5,
            rtol=1e-5,
        )

        t_exclude = self.skipped_time_points[0] if self.skipped_time_points else None
        if t_exclude is not None:
            # Simulate up to the held-out timepoint; its last frame is the prediction.
            traj = node.trajectory(
                batch[t_exclude - 1],
                t_span=torch.linspace(
                    self.timesteps[t_exclude - 1], self.timesteps[t_exclude], 101
                ),
            )
            X_mid_pred = traj[-1]
            # NOTE(review): this second integration (past the held-out point) is
            # computed but never used — presumably a leftover from plotting;
            # confirm before removing.
            traj = node.trajectory(
                batch[t_exclude - 1],
                t_span=torch.linspace(
                    self.timesteps[t_exclude - 1],
                    self.timesteps[t_exclude + 1],
                    101,
                ),
            )

            # 1-Wasserstein distance between prediction and the true marginal.
            EMD = wasserstein(X_mid_pred, batch[t_exclude], p=1)
            self.final_EMD = EMD

            self.log("test_EMD", EMD, on_step=False, on_epoch=True, prog_bar=True)
225
+
226
class FlowNetTrainCell(BranchFlowNetTrainBase):
    """Test variant that integrates each branch flow and saves 2D trajectory plots."""

    def test_step(self, batch, batch_idx):
        """Simulate every branch from the shared x0 samples and plot results.

        Fixes vs. original: the dataset inverse-transform was performed inside
        the per-branch loop, so with ``whiten`` and more than one branch the
        second iteration called ``.cpu()`` on a numpy array and crashed; it is
        now done once after the loop (matching ``FlowNetTrainLidar``). The
        redundant ``isinstance`` ternary with two identical arms is removed.
        """
        x0 = batch[0]["test_samples"][0]["x0"][0]  # [B, D]
        dataset_points = batch[0]["test_samples"][0]["dataset"][0]  # full dataset, [N, D]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []

        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            if self.whiten:
                # Undo the whitening transform on the simulated trajectory.
                traj_shape = traj.shape
                traj = traj.reshape(-1, traj.shape[-1])
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            traj = torch.as_tensor(traj)
            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)

        # Inverse-transform the point cloud once, outside the branch loop.
        if self.whiten:
            dataset_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    dataset_points.cpu().detach().numpy()
                )
            )

        dataset_2d = dataset_points[:, :2].cpu().numpy()

        # ===== Plot all 2D trajectories together with dataset and start/end points =====
        fig, ax = plt.subplots(figsize=(6, 5))
        ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
        for traj in all_trajs:
            traj_2d = traj[..., :2]  # [B, T, 2]
            for i in range(traj_2d.shape[0]):
                ax.plot(traj_2d[i, :, 0], traj_2d[i, :, 1], alpha=0.8, zorder=2)
                ax.scatter(traj_2d[i, 0, 0], traj_2d[i, 0, 1], c='green', s=10, label="t=0" if i == 0 else "", zorder=3)
                ax.scatter(traj_2d[i, -1, 0], traj_2d[i, -1, 1], c='red', s=10, label="t=1" if i == 0 else "", zorder=3)

        ax.set_title("All Branch Trajectories (2D) with Dataset")
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        plt.axis("equal")
        handles, labels = ax.get_legend_handles_labels()
        if labels:
            ax.legend()

        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        save_path = os.path.join(results_dir, 'figures')

        os.makedirs(save_path, exist_ok=True)
        plt.savefig(f'{save_path}/{self.args.data_name}_all_branches.png', dpi=300)
        plt.close()

        # ===== Plot each 2D trajectory separately with dataset and endpoints =====
        for i, traj in enumerate(all_trajs):
            traj_2d = traj[..., :2]
            fig, ax = plt.subplots(figsize=(6, 5))
            ax.scatter(dataset_2d[:, 0], dataset_2d[:, 1], c="gray", s=1, alpha=0.5, label="Dataset", zorder=1)
            for j in range(traj_2d.shape[0]):
                ax.plot(traj_2d[j, :, 0], traj_2d[j, :, 1], alpha=0.9, zorder=2)
                ax.scatter(traj_2d[j, 0, 0], traj_2d[j, 0, 1], c='green', s=12, label="t=0" if j == 0 else "", zorder=3)
                ax.scatter(traj_2d[j, -1, 0], traj_2d[j, -1, 1], c='red', s=12, label="t=1" if j == 0 else "", zorder=3)

            ax.set_title(f"Branch {i + 1} Trajectories (2D) with Dataset")
            ax.set_xlabel("x")
            ax.set_ylabel("y")
            plt.axis("equal")
            handles, labels = ax.get_legend_handles_labels()
            if labels:
                ax.legend()
            plt.savefig(f'{save_path}/{self.args.data_name}_branch_{i + 1}.png', dpi=300)
            plt.close()
306
+
307
class FlowNetTrainLidar(BranchFlowNetTrainBase):
    """Test variant for the 3D lidar experiment: integrate each branch and plot."""

    def test_step(self, batch, batch_idx):
        """Simulate all branches over the lidar cloud and save 3D figures."""
        # CombinedLoader may deliver either a dict or a (test, metric) tuple.
        if isinstance(batch, dict):
            main_batch = batch["test_samples"][0]
            metric_batch = batch["metric_samples"][0]
        else:
            main_batch, metric_batch = batch[0][0], batch[1][0]

        x0 = main_batch["x0"][0]  # [B, D]
        cloud_points = main_batch["dataset"][0]  # full dataset, [N, D]
        t_span = torch.linspace(0, 1, 101)

        branch_trajs = []
        for branch_net in self.flow_nets:
            ode = NeuralODE(
                flow_model_torch_wrapper(branch_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                branch_traj = ode.trajectory(x0, t_span).cpu()  # [T, B, D]

            if self.whiten:
                # Undo whitening; lidar points are 3-dimensional.
                orig_shape = branch_traj.shape
                flat = branch_traj.reshape(-1, 3)
                branch_traj = self.trainer.datamodule.scaler.inverse_transform(
                    flat.cpu().detach().numpy()
                ).reshape(orig_shape)

            branch_traj = torch.tensor(branch_traj)
            branch_trajs.append(torch.transpose(branch_traj, 0, 1))  # [B, T, D]

        # Inverse-transform the point cloud once
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # Output directory for figures
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        lidar_fig_dir = os.path.join(self.args.working_dir, 'results', run_name, 'figures')
        os.makedirs(lidar_fig_dir, exist_ok=True)

        # ===== Combined figure: every branch overlaid on the cloud =====
        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for idx, branch_traj in enumerate(branch_trajs):
            plot_lidar(ax, cloud_points, xs=branch_traj, branch_idx=idx)
        plt.savefig(os.path.join(lidar_fig_dir, 'lidar_all_branches.png'), dpi=300)
        plt.close()

        # ===== One figure per branch =====
        for idx, branch_traj in enumerate(branch_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=branch_traj, branch_idx=idx)
            plt.savefig(os.path.join(lidar_fig_dir, f'lidar_branch_{idx + 1}.png'), dpi=300)
            plt.close()
src/branch_growth_net_train.py ADDED
@@ -0,0 +1,994 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import wandb
5
+ import matplotlib.pyplot as plt
6
+ import pytorch_lightning as pl
7
+ from torch.optim import AdamW
8
+ from torchmetrics.functional import mean_squared_error
9
+ from torchdyn.core import NeuralODE
10
+ import numpy as np
11
+ import lpips
12
+ from .networks.utils import flow_model_torch_wrapper
13
+ from .utils import plot_lidar
14
+ from .ema import EMA
15
+ from torchdiffeq import odeint as odeint2
16
+ from .losses.energy_loss import EnergySolver, ReconsLoss
17
+
18
class GrowthNetTrain(pl.LightningModule):
    """Lightning module that trains per-branch growth (mass) networks on top of
    per-branch flow networks (BranchSBM-style branched trajectory training).

    When ``joint`` is False the flow networks are frozen and only the growth
    networks are optimized; when True both sets of networks are trained together.
    """

    def __init__(
        self,
        flow_nets,
        growth_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,

        state_cost=None,
        data_manifold_metric=None,

        joint = False
    ):
        """Store networks, hyperparameters, and loss weights.

        Args:
            flow_nets: container of per-branch flow (velocity) networks; frozen
                unless ``joint`` is True.
            growth_nets: list of per-branch growth-rate networks; its length
                defines the number of branches.
            skipped_time_points: optional time points excluded from training.
            ot_sampler: optional optimal-transport sampler (stored for callers).
            args: parsed CLI namespace; supplies optimizer and loss hyperparameters.
            state_cost: state-cost term forwarded to the energy ODE solver.
            data_manifold_metric: metric used for manifold-aware energy costs.
            joint: if True, also unfreeze and optimize the flow networks.
        """
        super().__init__()
        #self.save_hyperparameters()
        self.flow_nets = flow_nets

        # Freeze the flow networks when only the growth networks are trained.
        if not joint:
            for param in self.flow_nets.parameters():
                param.requires_grad = False

        self.growth_nets = growth_nets # list of growth networks for each branch

        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        # Optimizer hyperparameters (from CLI args).
        self.optimizer_name = args.growth_optimizer
        self.lr = args.growth_lr
        self.weight_decay = args.growth_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        self.args = args

        #branching
        self.state_cost = state_cost
        self.data_manifold_metric = data_manifold_metric
        self.branches = len(growth_nets)
        self.metric_clusters = args.metric_clusters

        self.recons_loss = ReconsLoss()

        # loss weights
        self.lambda_energy = args.lambda_energy
        self.lambda_mass = args.lambda_mass
        self.lambda_match = args.lambda_match
        self.lambda_recons = args.lambda_recons

        self.joint = joint
+
69
+ def forward(self, t, xt, branch_idx):
70
+ # output growth rate given branch_idx
71
+ return self.growth_nets[branch_idx](t, xt)
72
+
73
    def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False):
        """Simulate all branches along ``self.timesteps`` and assemble the training loss.

        The total loss combines four terms (weighted by the ``lambda_*`` args):
        energy accumulated by the ODE solver, a mass-conservation penalty (total
        branch weight should sum to 1, plus a penalty on negative weights), an MSE
        match between final and target branch weights, and a reconstruction loss
        between final and target branch positions.

        Args:
            main_batch: dict with "x0" -> (positions, weights) and per-branch
                endpoint entries "x1_{i}" (or "x1" for a single branch).
            metric_samples_batch: per-cluster metric samples; only used when
                ``args.manifold`` is set.
            validation: selects the val_* vs train_* logging keys.

        Returns:
            Scalar total loss tensor.
        """
        x0s = main_batch["x0"][0]
        w0s = main_batch["x0"][1]
        x1s_list = []
        w1s_list = []

        # Collect per-branch endpoint positions/weights (single-branch batches
        # use the plain "x1" key instead of "x1_{i}").
        if self.branches > 1:
            for i in range(self.branches):
                x1s_list.append([main_batch[f"x1_{i+1}"][0]])
                w1s_list.append([main_batch[f"x1_{i+1}"][1]])
        else:
            x1s_list.append([main_batch["x1"][0]])
            w1s_list.append([main_batch["x1"][1]])

        # Pick (start, end) metric-sample pairs per branch; cluster 0 is always
        # the root, and the mapping depends on the dataset's cluster layout.
        if self.args.manifold:
            #changed
            if self.metric_clusters == 7 and self.branches == 6:
                # Weinreb 6-branch scenario: cluster 0 (root) → clusters 1-6 (6 branches)
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
                    (metric_samples_batch[0], metric_samples_batch[3]), # x0 → x1_3 (branch 3)
                    (metric_samples_batch[0], metric_samples_batch[4]), # x0 → x1_4 (branch 4)
                    (metric_samples_batch[0], metric_samples_batch[5]), # x0 → x1_5 (branch 5)
                    (metric_samples_batch[0], metric_samples_batch[6]), # x0 → x1_6 (branch 6)
                ]
            elif self.metric_clusters == 4:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
                    (metric_samples_batch[0], metric_samples_batch[3]),
                ]
            elif self.metric_clusters == 3:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
                ]
            elif self.metric_clusters == 2 and self.branches == 2:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2)
                ]
            elif self.metric_clusters == 2:
                # For any number of branches with 2 metric clusters (initial vs remaining)
                # All branches use the same metric cluster pair
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]) # x0 → all branches
                ] * self.branches
            else:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                ]

        batch_size = x0s.shape[0]

        assert len(x1s_list) == self.branches, "Mismatch between x1s_list and expected branches"

        # Per-branch accumulators (energy/match/recons) and scalar mass terms.
        energy_loss = [0.] * self.branches
        mass_loss = 0.
        neg_weight_penalty = 0.
        match_loss = [0.] * self.branches
        recons_loss = [0.] * self.branches

        dtype = x0s[0].dtype
        #w0s = torch.zeros((batch_size, 1), dtype=dtype)
        m0s = torch.zeros_like(w0s, dtype=dtype)
        start_state = (x0s, w0s, m0s)

        # Every branch starts from the same positions; branch 0 carries the real
        # initial weights while the other branches start at zero mass.
        xt = [x0s.clone() for _ in range(self.branches)]
        w0_branch = torch.zeros_like(w0s, dtype=dtype)
        w0_branches = []
        w0_branches.append(w0s)
        for _ in range(self.branches - 1):
            w0_branches.append(w0_branch)
        #w0_branches = [w0_branch.clone() for _ in range(self.branches - 1)]
        wt = w0_branches

        mt = [m0s.clone() for _ in range(self.branches)]

        # loop through timesteps
        for step_idx, (s, t) in enumerate(zip(self.timesteps[:-1], self.timesteps[1:])):
            time = torch.Tensor([s, t])

            total_w_t = 0
            # loop through branches
            for i in range(self.branches):

                if self.args.manifold:
                    start_samples, end_samples = branch_sample_pairs[i]
                    samples = torch.cat([start_samples, end_samples], dim=0)
                else:
                    samples = None

                # initialize weight and energy
                start_state = (xt[i], wt[i], mt[i])

                # integrate (position, weight, energy) over [s, t] for branch i
                xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx)

                # final state of the sub-interval becomes the next start state
                xt_last = xt_next[-1]
                wt_last = wt_next[-1]
                mt_last = mt_next[-1]

                total_w_t += wt_last

                # energy accrued during this sub-interval; penalize negative mass
                energy_loss[i] += (mt_last - mt[i])
                neg_weight_penalty += torch.relu(-wt_last).sum()

                # update branch state (detached so gradients stay per-step)
                xt[i] = xt_last.clone().detach()
                wt[i] = wt_last.clone().detach()
                mt[i] = mt_last.clone().detach()

            # calculate mass loss from all branches
            target = torch.ones_like(total_w_t)
            mass_loss += mean_squared_error(total_w_t, target)

        # calculate loss that matches final weights
        for i in range(self.branches):
            match_loss[i] = mean_squared_error(wt[i], w1s_list[i][0])
            # compute reconstruction loss
            recons_loss[i] = self.recons_loss(xt[i], x1s_list[i][0])

        # average across time steps (loop runs len(timesteps)-1 times)
        mass_loss = mass_loss / max(len(self.timesteps) - 1, 1)


        # Weighted mean across branches (inversely weighted by cluster size)
        # Get cluster sizes from datamodule if available
        if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'cluster_sizes'):
            cluster_sizes = self.trainer.datamodule.cluster_sizes
            max_size = max(cluster_sizes)
            # Inverse weighting: smaller clusters get higher weight
            branch_weights = torch.tensor([max_size / size for size in cluster_sizes],
                                          dtype=energy_loss[0].dtype, device=energy_loss[0].device)
            # Normalize weights to sum to num_branches for fair comparison
            branch_weights = branch_weights * self.branches / branch_weights.sum()

            energy_loss = torch.mean(torch.stack([e.mean() for e in energy_loss]) * branch_weights)
            match_loss = torch.mean(torch.stack(match_loss) * branch_weights)
            recons_loss = torch.mean(torch.stack(recons_loss) * branch_weights)
        else:
            # Fallback to uniform weighting
            energy_loss = torch.mean(torch.stack([e.mean() for e in energy_loss]))
            match_loss = torch.mean(torch.stack(match_loss))
            recons_loss = torch.mean(torch.stack(recons_loss))

        # Combine the four terms with their configured weights.
        loss = (self.lambda_energy * energy_loss) + (self.lambda_mass * (mass_loss + neg_weight_penalty)) + (self.lambda_match * match_loss) \
            + (self.lambda_recons * recons_loss)

        # Log each component under the appropriate namespace/split.
        if self.joint:
            if validation:
                self.log("JointTrain/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
            else:
                self.log("JointTrain/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("JointTrain/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
        else:
            if validation:
                self.log("GrowthNet/val_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/val_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)
            else:
                self.log("GrowthNet/train_mass_loss", mass_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_neg_penalty_loss", neg_weight_penalty, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_match_loss", match_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_energy_loss", energy_loss, on_step=False, on_epoch=True, prog_bar=True)
                self.log("GrowthNet/train_recons_loss", recons_loss, on_step=False, on_epoch=True, prog_bar=True)

        return loss
252
+
253
+ def take_step(self, t, start_state, branch_idx, samples=None, timestep_idx=0):
254
+
255
+ flow_net = self.flow_nets[branch_idx]
256
+ growth_net = self.growth_nets[branch_idx]
257
+
258
+
259
+ x_t, w_t, m_t = odeint2(EnergySolver(flow_net, growth_net, self.state_cost, self.data_manifold_metric, samples, timestep_idx), start_state, t, options=dict(step_size=0.1),method='euler')
260
+
261
+ return x_t, w_t, m_t
262
+
263
+ def training_step(self, batch, batch_idx):
264
+ if isinstance(batch, (list, tuple)):
265
+ batch = batch[0]
266
+ if isinstance(batch, dict) and "train_samples" in batch:
267
+ main_batch = batch["train_samples"]
268
+ metric_batch = batch["metric_samples"]
269
+ if isinstance(main_batch, tuple):
270
+ main_batch = main_batch[0]
271
+ if isinstance(metric_batch, tuple):
272
+ metric_batch = metric_batch[0]
273
+ else:
274
+ # Fallback
275
+ main_batch = batch.get("train_samples", batch)
276
+ metric_batch = batch.get("metric_samples", [])
277
+
278
+ self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
279
+ loss = self._compute_loss(main_batch, metric_batch, validation=False)
280
+
281
+ if self.joint:
282
+ self.log(
283
+ "JointTrain/train_loss",
284
+ loss,
285
+ on_step=False,
286
+ on_epoch=True,
287
+ prog_bar=True,
288
+ logger=True,
289
+ )
290
+ else:
291
+ self.log(
292
+ "GrowthNet/train_loss",
293
+ loss,
294
+ on_step=False,
295
+ on_epoch=True,
296
+ prog_bar=True,
297
+ logger=True,
298
+ )
299
+
300
+ return loss
301
+
302
+ def validation_step(self, batch, batch_idx):
303
+ if isinstance(batch, (list, tuple)):
304
+ batch = batch[0]
305
+ if isinstance(batch, dict) and "val_samples" in batch:
306
+ main_batch = batch["val_samples"]
307
+ metric_batch = batch["metric_samples"]
308
+ if isinstance(main_batch, tuple):
309
+ main_batch = main_batch[0]
310
+ if isinstance(metric_batch, tuple):
311
+ metric_batch = metric_batch[0]
312
+ else:
313
+ # Fallback
314
+ main_batch = batch.get("val_samples", batch)
315
+ metric_batch = batch.get("metric_samples", [])
316
+
317
+ self.timesteps = torch.linspace(0.0, 1.0, len(main_batch["x0"])).tolist()
318
+ val_loss = self._compute_loss(main_batch, metric_batch, validation=True)
319
+
320
+ if self.joint:
321
+ self.log(
322
+ "JointTrain/val_loss",
323
+ val_loss,
324
+ on_step=False,
325
+ on_epoch=True,
326
+ prog_bar=True,
327
+ logger=True,
328
+ )
329
+ else:
330
+ self.log(
331
+ "GrowthNet/val_loss",
332
+ val_loss,
333
+ on_step=False,
334
+ on_epoch=True,
335
+ prog_bar=True,
336
+ logger=True,
337
+ )
338
+ return val_loss
339
+
340
+ def optimizer_step(self, *args, **kwargs):
341
+ super().optimizer_step(*args, **kwargs)
342
+ for net in self.growth_nets:
343
+ if isinstance(net, EMA):
344
+ net.update_ema()
345
+ if self.joint:
346
+ for net in self.flow_nets:
347
+ if isinstance(net, EMA):
348
+ net.update_ema()
349
+
350
+ def configure_optimizers(self):
351
+ params = []
352
+ for net in self.growth_nets:
353
+ params += list(net.parameters())
354
+
355
+ if self.joint:
356
+ for net in self.flow_nets:
357
+ params += list(net.parameters())
358
+
359
+ if self.optimizer_name == "adamw":
360
+ optimizer = AdamW(
361
+ params,
362
+ lr=self.lr,
363
+ weight_decay=self.weight_decay,
364
+ )
365
+ elif self.optimizer_name == "adam":
366
+ optimizer = torch.optim.Adam(
367
+ params,
368
+ lr=self.lr,
369
+ )
370
+
371
+ return optimizer
372
+
373
    @torch.no_grad()
    def get_mass_and_position(self, main_batch, metric_samples_batch=None):
        """Roll out every branch over t in [0, 1] (100 Euler steps) and record diagnostics.

        Args:
            main_batch: dict with "x0" -> (positions, weights), or a container
                whose first element is that dict.
            metric_samples_batch: per-cluster metric samples; only used when
                ``args.manifold`` is set.

        Returns:
            Tuple of (time_points, final per-branch positions, per-branch
            position trajectories, per-branch mean mass over time, per-branch
            mean energy over time, per-branch per-sample weights over time).
        """
        # Accept either the batch dict directly or a one-element wrapper.
        if isinstance(main_batch, dict):
            main_batch = main_batch
        else:
            main_batch = main_batch[0]

        x0s = main_batch["x0"][0]
        w0s = main_batch["x0"][1]

        # Pick (start, end) metric-sample pairs per branch, mirroring _compute_loss.
        if self.args.manifold:
            if self.metric_clusters == 4:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
                    (metric_samples_batch[0], metric_samples_batch[3]),
                ]
            elif self.metric_clusters == 3:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
                ]
            elif self.metric_clusters == 2 and self.branches == 2:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2)
                ]
            elif self.metric_clusters == 2:
                # For any number of branches with 2 metric clusters (initial vs remaining)
                # All branches use the same metric cluster pair
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]) # x0 → all branches
                ] * self.branches
            else:
                branch_sample_pairs = [
                    (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
                ]

        batch_size = x0s.shape[0]
        dtype = x0s[0].dtype

        m0s = torch.zeros_like(w0s, dtype=dtype)
        xt = [x0s.clone() for _ in range(self.branches)]

        # Branch 0 starts with the real weights; other branches start at zero mass.
        w0_branch = torch.zeros_like(w0s, dtype=dtype)
        w0_branches = []
        w0_branches.append(w0s)
        for _ in range(self.branches - 1):
            w0_branches.append(w0_branch)

        wt = w0_branches
        mt = [m0s.clone() for _ in range(self.branches)]

        time_points = []
        mass_over_time = [[] for _ in range(self.branches)]
        energy_over_time = [[] for _ in range(self.branches)]
        # record per-sample weights at each time for each branch (to allow OT with per-sample masses)
        weights_over_time = [[] for _ in range(self.branches)]
        all_trajs = [[] for _ in range(self.branches)]

        t_span = torch.linspace(0, 1, 101)
        for step_idx, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])):
            time_points.append(t.item())
            time = torch.Tensor([s, t])

            for i in range(self.branches):
                if self.args.manifold:
                    start_samples, end_samples = branch_sample_pairs[i]
                    samples = torch.cat([start_samples, end_samples], dim=0)
                else:
                    samples = None

                # Integrate one sub-interval and carry the end state forward.
                start_state = (xt[i], wt[i], mt[i])
                xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx)

                xt[i] = xt_next[-1].clone().detach()
                wt[i] = wt_next[-1].clone().detach()
                mt[i] = mt_next[-1].clone().detach()

                all_trajs[i].append(xt[i].clone().detach())
                mass_over_time[i].append(wt[i].mean().item())
                energy_over_time[i].append(mt[i].mean().item())
                # store per-sample weights (clone to detach from graph)
                try:
                    weights_over_time[i].append(wt[i].clone().detach())
                except Exception:
                    # fallback: store mean as singleton tensor
                    weights_over_time[i].append(torch.tensor(wt[i].mean().item()).unsqueeze(0))

        return time_points, xt, all_trajs, mass_over_time, energy_over_time, weights_over_time
463
+
464
+ @torch.no_grad()
465
+ def _plot_mass_and_energy(self, main_batch, metric_samples_batch=None, save_dir=None):
466
+ x0s = main_batch["x0"][0]
467
+ w0s = main_batch["x0"][1]
468
+
469
+ if self.args.manifold:
470
+ if self.metric_clusters == 7 and self.branches == 6:
471
+ # Weinreb 6-branch scenario: cluster 0 (root) → clusters 1-6 (6 branches)
472
+ branch_sample_pairs = [
473
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
474
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
475
+ (metric_samples_batch[0], metric_samples_batch[3]), # x0 → x1_3 (branch 3)
476
+ (metric_samples_batch[0], metric_samples_batch[4]), # x0 → x1_4 (branch 4)
477
+ (metric_samples_batch[0], metric_samples_batch[5]), # x0 → x1_5 (branch 5)
478
+ (metric_samples_batch[0], metric_samples_batch[6]), # x0 → x1_6 (branch 6)
479
+ ]
480
+ elif self.metric_clusters == 4:
481
+ branch_sample_pairs = [
482
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
483
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
484
+ (metric_samples_batch[0], metric_samples_batch[3]),
485
+ ]
486
+ elif self.metric_clusters == 3:
487
+ branch_sample_pairs = [
488
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
489
+ (metric_samples_batch[0], metric_samples_batch[2]), # x0 → x1_2 (branch 2)
490
+ ]
491
+ elif self.metric_clusters == 2 and self.branches == 2:
492
+ branch_sample_pairs = [
493
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
494
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_2 (branch 2)
495
+ ]
496
+ else:
497
+ branch_sample_pairs = [
498
+ (metric_samples_batch[0], metric_samples_batch[1]), # x0 → x1_1 (branch 1)
499
+ ]
500
+
501
+ batch_size = x0s.shape[0]
502
+ dtype = x0s[0].dtype
503
+
504
+ m0s = torch.zeros_like(w0s, dtype=dtype)
505
+ xt = [x0s.clone() for _ in range(self.branches)]
506
+
507
+ w0_branch = torch.zeros_like(w0s, dtype=dtype)
508
+ w0_branches = []
509
+ w0_branches.append(w0s)
510
+ for _ in range(self.branches - 1):
511
+ w0_branches.append(w0_branch)
512
+
513
+ wt = w0_branches
514
+ mt = [m0s.clone() for _ in range(self.branches)]
515
+
516
+ time_points = []
517
+ mass_over_time = [[] for _ in range(self.branches)]
518
+ energy_over_time = [[] for _ in range(self.branches)]
519
+
520
+ t_span = torch.linspace(0, 1, 101)
521
+ for step_idx, (s, t) in enumerate(zip(t_span[:-1], t_span[1:])):
522
+ time_points.append(t.item())
523
+ time = torch.Tensor([s, t])
524
+
525
+ for i in range(self.branches):
526
+ if self.args.manifold:
527
+ start_samples, end_samples = branch_sample_pairs[i]
528
+ samples = torch.cat([start_samples, end_samples], dim=0)
529
+ else:
530
+ samples = None
531
+
532
+ start_state = (xt[i], wt[i], mt[i])
533
+ xt_next, wt_next, mt_next = self.take_step(time, start_state, i, samples, timestep_idx=step_idx)
534
+
535
+ xt[i] = xt_next[-1].clone().detach()
536
+ wt[i] = wt_next[-1].clone().detach()
537
+ mt[i] = mt_next[-1].clone().detach()
538
+
539
+ mass_over_time[i].append(wt[i].mean().item())
540
+ energy_over_time[i].append(mt[i].mean().item())
541
+
542
+ if save_dir is None:
543
+ run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
544
+ save_dir = os.path.join(self.args.working_dir, 'results', run_name, 'figures')
545
+ os.makedirs(save_dir, exist_ok=True)
546
+
547
+ # Use tab10 colormap to get visually distinct colors
548
+ if self.args.branches == 3:
549
+ branch_colors = ['#9793F8', '#50B2D7', '#D577FF'] # tuple of RGBs
550
+ else:
551
+ branch_colors = ['#50B2D7', '#D577FF'] # tuple of RGBs
552
+
553
+ # --- Plot Mass ---
554
+ plt.figure(figsize=(8, 5))
555
+ for i in range(self.branches):
556
+ color = branch_colors[i]
557
+ plt.plot(time_points, mass_over_time[i], color=color, linewidth=2.5, label=f"Mass Branch {i}")
558
+ plt.xlabel("Time")
559
+ plt.ylabel("Mass")
560
+ plt.title("Mass Evolution per Branch")
561
+ plt.legend()
562
+ plt.grid(True)
563
+ if self.joint:
564
+ mass_path = os.path.join(save_dir, f"{self.args.data_name}_joint_mass.png")
565
+ else:
566
+ mass_path = os.path.join(save_dir, f"{self.args.data_name}_growth_mass.png")
567
+ plt.savefig(mass_path, dpi=300, bbox_inches="tight")
568
+ plt.close()
569
+
570
+ # --- Plot Energy ---
571
+ plt.figure(figsize=(8, 5))
572
+ for i in range(self.branches):
573
+ color = branch_colors[i]
574
+ plt.plot(time_points, energy_over_time[i], color=color, linewidth=2.5, label=f"Energy Branch {i}")
575
+ plt.xlabel("Time")
576
+ plt.ylabel("Energy")
577
+ plt.title("Energy Evolution per Branch")
578
+ plt.legend()
579
+ plt.grid(True)
580
+ if self.joint:
581
+ energy_path = os.path.join(save_dir, f"{self.args.data_name}_joint_energy.png")
582
+ else:
583
+ energy_path = os.path.join(save_dir, f"{self.args.data_name}_growth_energy.png")
584
+ plt.savefig(energy_path, dpi=300, bbox_inches="tight")
585
+ plt.close()
586
+
587
+
588
class GrowthNetTrainLidar(GrowthNetTrain):
    """Lidar variant: at test time, plot mass/energy curves and render branch
    trajectories as 3-D figures over the full point cloud."""

    def test_step(self, batch, batch_idx):
        """Roll out each branch's flow ODE from the shared start points and save figures.

        Args:
            batch: CombinedLoader output; either a dict with
                'test_samples'/'metric_samples' keys or a
                (test_samples, metric_samples) tuple.
            batch_idx: index of the test batch (unused).
        """
        # Handle both tuple and dict batch formats from CombinedLoader
        if isinstance(batch, dict):
            main_batch = batch["test_samples"][0]
            metric_batch = batch["metric_samples"][0]
        else:
            # batch is a tuple: (test_samples, metric_samples)
            main_batch = batch[0][0]
            metric_batch = batch[1][0]

        self._plot_mass_and_energy(main_batch, metric_batch)

        x0 = main_batch["x0"][0]  # [B, D]
        cloud_points = main_batch["dataset"][0]  # full dataset, [N, D]
        t_span = torch.linspace(0, 1, 101)

        all_trajs = []

        for i, flow_net in enumerate(self.flow_nets):
            node = NeuralODE(
                flow_model_torch_wrapper(flow_net),
                solver="euler",
                sensitivity="adjoint",
            )

            with torch.no_grad():
                traj = node.trajectory(x0, t_span).cpu()  # [T, B, D]

            if self.whiten:
                traj_shape = traj.shape
                # Un-whiten in feature space. Fix: use the actual feature
                # dimension instead of a hard-coded 3 so non-3D data also works.
                traj = traj.reshape(-1, traj_shape[-1])
                traj = self.trainer.datamodule.scaler.inverse_transform(
                    traj.cpu().detach().numpy()
                ).reshape(traj_shape)

            traj = torch.tensor(traj)
            traj = torch.transpose(traj, 0, 1)  # [B, T, D]
            all_trajs.append(traj)

        # Inverse-transform the point cloud once
        if self.whiten:
            cloud_points = torch.tensor(
                self.trainer.datamodule.scaler.inverse_transform(
                    cloud_points.cpu().detach().numpy()
                )
            )

        # Create directory for saving figures
        run_name = self.args.run_name if hasattr(self.args, 'run_name') and self.args.run_name else self.args.data_name
        results_dir = os.path.join(self.args.working_dir, 'results', run_name)
        lidar_fig_dir = os.path.join(results_dir, 'figures')
        os.makedirs(lidar_fig_dir, exist_ok=True)
        prefix = 'joint' if self.joint else 'growth'

        # ===== Plot all trajectories together =====
        fig = plt.figure(figsize=(6, 5))
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
        ax.view_init(elev=30, azim=-115, roll=0)
        for i, traj in enumerate(all_trajs):
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
        plt.savefig(os.path.join(lidar_fig_dir, f'{prefix}_lidar_all_branches.png'), dpi=300)
        plt.close()

        # ===== Plot each trajectory separately =====
        for i, traj in enumerate(all_trajs):
            fig = plt.figure(figsize=(6, 5))
            ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
            ax.view_init(elev=30, azim=-115, roll=0)
            plot_lidar(ax, cloud_points, xs=traj, branch_idx=i)
            plt.savefig(os.path.join(lidar_fig_dir, f'{prefix}_lidar_branch_{i + 1}.png'), dpi=300)
            plt.close()
664
+
665
class GrowthNetTrainCell(GrowthNetTrain):
    """Cell-data variant: the test pass only produces mass/energy diagnostic plots."""

    def test_step(self, batch, batch_idx):
        """Extract the test/metric batches and delegate to _plot_mass_and_energy."""
        # scrna/tahoe dataloaders wrap the batch dict in one extra list level.
        source = batch[0] if self.args.data_type in ["scrna", "tahoe"] else batch
        main_batch = source["test_samples"][0]
        metric_batch = source["metric_samples"][0]
        self._plot_mass_and_energy(main_batch, metric_batch)
675
+
676
+
677
class SequentialGrowthNetTrain(pl.LightningModule):
    """
    Sequential growth network training for multi-timepoint data.
    Learns growth rates for transitions between consecutive timepoints.
    """
    def __init__(
        self,
        flow_nets,
        growth_nets,
        skipped_time_points=None,
        ot_sampler=None,
        args=None,
        data_manifold_metric=None,
        joint=False
    ):
        """Store networks and hyperparameters for sequential (multi-timepoint) training.

        Args:
            flow_nets: per-branch flow networks; frozen unless ``joint`` is True.
            growth_nets: per-branch growth-rate networks; length defines branches.
            skipped_time_points: optional time points excluded from training.
            ot_sampler: optional optimal-transport sampler (stored for callers).
            args: parsed CLI namespace with optimizer/loss hyperparameters.
            data_manifold_metric: metric used for manifold-aware energy costs.
            joint: if True, also unfreeze and optimize the flow networks.
        """
        super().__init__()
        self.flow_nets = flow_nets

        # Freeze flow networks unless training jointly with the growth nets.
        if not joint:
            for param in self.flow_nets.parameters():
                param.requires_grad = False

        self.growth_nets = growth_nets
        self.ot_sampler = ot_sampler
        self.skipped_time_points = skipped_time_points

        # Optimizer hyperparameters (from CLI args).
        self.optimizer_name = args.growth_optimizer
        self.lr = args.growth_lr
        self.weight_decay = args.growth_weight_decay
        self.whiten = args.whiten
        self.working_dir = args.working_dir

        self.args = args
        self.data_manifold_metric = data_manifold_metric
        self.branches = len(growth_nets)
        self.metric_clusters = args.metric_clusters

        self.recons_loss = ReconsLoss()

        # loss weights
        self.lambda_energy = args.lambda_energy
        self.lambda_mass = args.lambda_mass
        self.lambda_match = args.lambda_match
        self.lambda_recons = args.lambda_recons

        self.joint = joint
        # Filled in by setup() from the datamodule's timepoint data.
        self.num_timepoints = None
        self.timepoint_keys = None
725
+
726
+ def forward(self, t, xt, branch_idx):
727
+ return self.growth_nets[branch_idx](t, xt)
728
+
729
+ def setup(self, stage=None):
730
+ """Initialize timepoint keys before training/validation starts."""
731
+ if self.timepoint_keys is None:
732
+ timepoint_data = self.trainer.datamodule.get_timepoint_data()
733
+ self.timepoint_keys = [k for k in sorted(timepoint_data.keys())
734
+ if not any(x in k for x in ['_', 'time_labels'])]
735
+ self.num_timepoints = len(self.timepoint_keys)
736
+ print(f"Training sequential growth for {self.num_timepoints} timepoints: {self.timepoint_keys}")
737
+
738
def _compute_loss(self, main_batch, metric_samples_batch=None, validation=False):
    """Compute loss for sequential growth between timepoints.

    Iterates over consecutive timepoint transitions; the final transition
    fans out into the per-branch endpoints (``x1_{b+1}``), optionally
    weighted inversely by cluster size.  Per-component losses are averaged
    over the number of transitions for logging; the unaveraged total is
    returned.

    Args:
        main_batch: dict mapping batch keys (e.g. ``"x0"``, ``"x1_1"``) to
            ``(samples, weights)`` pairs.
        metric_samples_batch: per-cluster metric-sample tensors; required
            when ``self.args.manifold`` is set.
        validation: if True, log under ``val_*`` instead of ``train_*``.

    Returns:
        Scalar total loss summed over all transitions and branches.
    """
    # Pair metric samples (source cluster, per-branch target cluster).
    branch_sample_pairs = None
    if self.args.manifold:
        if metric_samples_batch is None:
            # Fail loudly instead of an opaque TypeError on indexing below.
            raise ValueError("metric_samples_batch is required when args.manifold is set")
        if self.metric_clusters == 2:
            branch_sample_pairs = [
                (metric_samples_batch[0], metric_samples_batch[1])
            ] * self.branches
        else:
            branch_sample_pairs = []
            for b in range(self.branches):
                # Fall back to the first target cluster when a branch has no
                # dedicated metric-sample cluster.
                target = b + 1 if b + 1 < len(metric_samples_batch) else 1
                branch_sample_pairs.append(
                    (metric_samples_batch[0], metric_samples_batch[target])
                )

    total_loss = 0
    total_energy_loss = 0
    total_mass_loss = 0
    total_match_loss = 0
    total_recons_loss = 0
    num_transitions = 0

    # Process each consecutive timepoint transition.
    for i in range(len(self.timepoint_keys) - 1):
        t_curr_key = self.timepoint_keys[i]
        t_next_key = self.timepoint_keys[i + 1]

        # Map a timepoint key like "t2" / "tfinal" to a batch key "x2" / "x1".
        batch_curr_key = f"x{t_curr_key.replace('t', '').replace('final', '1')}"
        x_curr = main_batch[batch_curr_key][0]
        w_curr = main_batch[batch_curr_key][1]

        if i == len(self.timepoint_keys) - 2:
            # Final transition: fan out to the per-branch endpoints.
            # Inverse cluster-size weighting (when available) keeps small
            # clusters from being dominated by large ones.
            if hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'cluster_sizes'):
                cluster_sizes = self.trainer.datamodule.cluster_sizes
                max_size = max(cluster_sizes)
                branch_weights = [max_size / size for size in cluster_sizes]
            else:
                branch_weights = [1.0] * self.branches

            for b in range(self.branches):
                x_next = main_batch[f"x1_{b+1}"][0]
                w_next = main_batch[f"x1_{b+1}"][1]

                loss, energy_l, mass_l, match_l, recons_l = self._compute_transition_loss(
                    x_curr, w_curr, x_next, w_next, b, i,
                    branch_sample_pairs[b] if self.args.manifold else None
                )
                # Apply the per-branch weight to every component.
                total_loss += loss * branch_weights[b]
                total_energy_loss += energy_l * branch_weights[b]
                total_mass_loss += mass_l * branch_weights[b]
                total_match_loss += match_l * branch_weights[b]
                total_recons_loss += recons_l * branch_weights[b]
                num_transitions += 1
        else:
            # Regular consecutive timepoints: every branch sees the same
            # (x_curr -> x_next) pair, unweighted.
            batch_next_key = f"x{t_next_key.replace('t', '').replace('final', '1')}"
            x_next = main_batch[batch_next_key][0]
            w_next = main_batch[batch_next_key][1]

            for b in range(self.branches):
                loss, energy_l, mass_l, match_l, recons_l = self._compute_transition_loss(
                    x_curr, w_curr, x_next, w_next, b, i,
                    branch_sample_pairs[b] if self.args.manifold else None
                )
                total_loss += loss
                total_energy_loss += energy_l
                total_mass_loss += mass_l
                total_match_loss += match_l
                total_recons_loss += recons_l
                num_transitions += 1

    # Average components for logging; with zero transitions the totals are
    # still zero, so dividing by 1 leaves them unchanged.
    denom = num_transitions if num_transitions > 0 else 1
    avg_energy_loss = total_energy_loss / denom
    avg_mass_loss = total_mass_loss / denom
    avg_match_loss = total_match_loss / denom
    avg_recons_loss = total_recons_loss / denom

    # Log each averaged component under the appropriate namespace/split.
    group = "JointTrain" if self.joint else "GrowthNet"
    split = "val" if validation else "train"
    for name, value in (
        ("energy", avg_energy_loss),
        ("mass", avg_mass_loss),
        ("match", avg_match_loss),
        ("recons", avg_recons_loss),
    ):
        self.log(f"{group}/{split}_{name}_loss", value, on_step=False, on_epoch=True, prog_bar=True)

    return total_loss
def _compute_transition_loss(self, x0, w0, x1, w1, branch_idx, transition_idx, metric_pair):
    """Compute the loss terms for a single timepoint transition.

    Simulates the branch's flow ODE from ``x0`` over t in [0, 1] with an
    Euler solver and accumulates, across the simulated timesteps:

    * an energy term — manifold kinetic + potential when ``metric_pair``
      is given, otherwise the mean squared growth rate;
    * a mass-conservation term comparing total evolved weight to ``w1``;
    * a penalty on negative evolved weights.

    At the final timestep it adds a weight-matching (MSE) term and a
    reconstruction term against ``x1``.

    Returns:
        ``(total, energy, mass + negativity penalty, match, reconstruction)``.
    """
    # Optional minibatch OT re-pairing of the endpoints.
    if self.ot_sampler is not None:
        x0, x1 = self.ot_sampler.sample_plan(x0, x1, replace=True)

    t_span = torch.linspace(0, 1, 10, device=x0.device)

    wrapped_flow = flow_model_torch_wrapper(self.flow_nets[branch_idx])
    ode = NeuralODE(wrapped_flow, solver="euler", sensitivity="adjoint")

    # NOTE(review): the trajectory is simulated under no_grad, so this loss
    # does not backpropagate into the flow network — confirm this is the
    # intended behavior for joint training.
    with torch.no_grad():
        traj = ode.trajectory(x0, t_span)

    energy_loss = 0
    mass_loss = 0
    neg_weight_penalty = 0

    for t, xt in zip(t_span, traj):
        # Per-sample growth rate at time t (t broadcast over the batch).
        growth = self.growth_nets[branch_idx](t.unsqueeze(0).expand(xt.shape[0]), xt)

        if self.args.manifold and metric_pair is not None:
            # Manifold energy: kinetic + potential w.r.t. the metric samples.
            start_samples, end_samples = metric_pair
            samples = torch.cat([start_samples, end_samples], dim=0)
            _, kinetic, potential = self.data_manifold_metric.calculate_velocity(
                xt, torch.zeros_like(xt), samples, transition_idx
            )
            energy = kinetic + potential
        else:
            # Euclidean fallback: penalize the squared growth rate.
            energy = (growth ** 2).sum(dim=-1)
        energy_loss = energy_loss + energy.mean()

        # Mass conservation: evolve the source weights by the exponentiated
        # growth and compare total mass against the target's.
        growth_sum = growth.sum(dim=-1, keepdim=True)  # keepdim so it broadcasts with w0
        wt = w0 * torch.exp(growth_sum)
        mass_loss = mass_loss + (wt.sum() - w1.sum()).abs()
        neg_weight_penalty = neg_weight_penalty + torch.relu(-wt).sum()

    # Match / reconstruction terms at the final simulated time; ``wt`` is
    # deliberately the value from the last loop iteration.
    match_loss = mean_squared_error(wt, w1)
    recons_loss = self.recons_loss(traj[-1], x1)

    total = (
        self.lambda_energy * energy_loss
        + self.lambda_mass * (mass_loss + neg_weight_penalty)
        + self.lambda_match * match_loss
        + self.lambda_recons * recons_loss
    )

    return total, energy_loss, mass_loss + neg_weight_penalty, match_loss, recons_loss
def training_step(self, batch, batch_idx):
    """Lightning training step: unwrap the batch, compute and log the loss."""
    # Some dataloaders wrap the batch dict in a list/tuple.
    if isinstance(batch, (list, tuple)):
        batch = batch[0]
    main_batch = batch["train_samples"]
    metric_batch = batch["metric_samples"]
    # Each sub-batch may itself arrive as a 1-tuple.
    if isinstance(main_batch, tuple):
        main_batch = main_batch[0]
    if isinstance(metric_batch, tuple):
        metric_batch = metric_batch[0]

    loss = self._compute_loss(main_batch, metric_batch)

    # Log under the namespace matching the training mode.
    prefix = "JointTrain" if self.joint else "GrowthNet"
    self.log(
        f"{prefix}/train_loss",
        loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
    )

    return loss
def validation_step(self, batch, batch_idx):
    """Lightning validation step: unwrap the batch, compute and log the loss."""
    # Some dataloaders wrap the batch dict in a list/tuple.
    if isinstance(batch, (list, tuple)):
        batch = batch[0]
    main_batch = batch["val_samples"]
    metric_batch = batch["metric_samples"]
    # Each sub-batch may itself arrive as a 1-tuple.
    if isinstance(main_batch, tuple):
        main_batch = main_batch[0]
    if isinstance(metric_batch, tuple):
        metric_batch = metric_batch[0]

    loss = self._compute_loss(main_batch, metric_batch, validation=True)

    # Log under the namespace matching the training mode.
    prefix = "JointTrain" if self.joint else "GrowthNet"
    self.log(
        f"{prefix}/val_loss",
        loss,
        on_step=False,
        on_epoch=True,
        prog_bar=True,
        logger=True,
    )

    return loss
def configure_optimizers(self):
    """Build the optimizer over the growth nets (plus flow nets when joint).

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` over all collected
        parameters, selected by ``self.optimizer_name``.

    Raises:
        ValueError: if ``self.optimizer_name`` is neither ``"adam"`` nor
            ``"adamw"``.
    """
    params = [p for net in self.growth_nets for p in net.parameters()]
    if self.joint:
        # Joint training also optimizes the flow networks.
        params += [p for net in self.flow_nets for p in net.parameters()]

    if self.optimizer_name == "adam":
        return torch.optim.Adam(params, lr=self.lr)
    if self.optimizer_name == "adamw":
        return torch.optim.AdamW(
            params,
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
    # Previously an unknown name fell through and crashed later with an
    # opaque UnboundLocalError; fail here with a clear message instead.
    raise ValueError(f"Unsupported optimizer: {self.optimizer_name!r} (expected 'adam' or 'adamw')")