| | import sys |
| | import os |
| | import argparse |
| | import copy |
| | import time |
| | import json |
| |
|
| | import torch.nn as nn |
| | import wandb |
| | from pytorch_lightning import Trainer |
| | from pytorch_lightning.loggers import WandbLogger |
| | from torchcfm.optimal_transport import OTPlanSampler |
| |
|
| | from parsers import parse_args |
| | from train_utils import load_config, merge_config, generate_group_string, dataset_name2datapath, create_callbacks |
| | from src.branchsbm import BranchSBM |
| | from src.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar |
| | from src.branch_flow_net_test import ( |
| | FlowNetTestLidar, FlowNetTestMouse, FlowNetTestClonidine, FlowNetTestTrametinib, FlowNetTestVeres |
| | ) |
| | from src.branch_interpolant_train import BranchInterpolantTrain |
| | from src.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar, SequentialGrowthNetTrain |
| | from src.networks.flow_mlp import VelocityNet |
| | from src.networks.growth_mlp import GrowthNet |
| | from src.networks.interpolant_mlp import GeoPathMLP |
| | from src.utils import set_seed |
| | from src.ema import EMA |
| | from src.geo_metrics.metric_factory import DataManifoldMetric |
| | from dataloaders.mouse_data import WeightedBranchedCellDataModule, SingleBranchCellDataModule |
| | from dataloaders.three_branch_data import ThreeBranchTahoeDataModule |
| | from dataloaders.clonidine_v2_data import ClonidineV2DataModule |
| | from dataloaders.clonidine_single_branch import ClonidineSingleBranchDataModule |
| | from dataloaders.trametinib_single import TrametinibSingleBranchDataModule |
| | from dataloaders.lidar_data import WeightedBranchedLidarDataModule |
| | from dataloaders.lidar_data_single import LidarSingleDataModule |
| | from dataloaders.veres_leiden_data import WeightedBranchedVeresDataModule |
| |
|
def _build_datamodule(args):
    """Instantiate the datamodule registered for ``args.data_name``.

    Args:
        args: Run configuration; only ``data_name`` is read here, the whole
            namespace is forwarded to the datamodule constructor.

    Returns:
        The constructed datamodule.

    Raises:
        ValueError: If ``args.data_name`` is not a known dataset name.
    """
    datamodule_classes = {
        "lidar": WeightedBranchedLidarDataModule,
        "lidarsingle": LidarSingleDataModule,
        "mouse": WeightedBranchedCellDataModule,
        "mousesingle": SingleBranchCellDataModule,
        "clonidine50D": ClonidineV2DataModule,
        "clonidine100D": ClonidineV2DataModule,
        "clonidine150D": ClonidineV2DataModule,
        "clonidine50Dsingle": ClonidineSingleBranchDataModule,
        "trametinib": ThreeBranchTahoeDataModule,
        "trametinibsingle": TrametinibSingleBranchDataModule,
        "veres": WeightedBranchedVeresDataModule,
    }
    datamodule_cls = datamodule_classes.get(args.data_name)
    if datamodule_cls is None:
        # Fail early with a clear message instead of an unbound-variable error
        # further down the pipeline.
        raise ValueError(
            f"Unknown data_name {args.data_name!r}; expected one of "
            f"{sorted(datamodule_classes)}"
        )
    return datamodule_cls(args=args)


def _build_networks(args, branches):
    """Create one flow, geopath and growth network per branch.

    Networks are instantiated in the order flow -> geopath -> growth for each
    branch so that parameter initialisation consumes the seeded RNG in a
    fixed order. Only branch 0's growth net is built with ``negative=True``.
    When ``args.ema_decay`` is not ``None`` every network is wrapped in ``EMA``.

    Returns:
        Tuple ``(flow_nets, geopath_nets, growth_nets)`` of ``nn.ModuleList``.
    """
    flow_nets = nn.ModuleList()
    geopath_nets = nn.ModuleList()
    growth_nets = nn.ModuleList()

    for branch_idx in range(branches):
        flow_net = VelocityNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_flow,
            activation=args.activation_flow,
            batch_norm=False,
        )
        geopath_net = GeoPathMLP(
            input_dim=args.dim,
            hidden_dims=args.hidden_dims_geopath,
            time_geopath=args.time_geopath,
            activation=args.activation_geopath,
            batch_norm=False,
        )
        growth_net = GrowthNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_growth,
            activation=args.activation_growth,
            batch_norm=False,
            negative=(branch_idx == 0),
        )

        if args.ema_decay is not None:
            flow_net = EMA(model=flow_net, decay=args.ema_decay)
            geopath_net = EMA(model=geopath_net, decay=args.ema_decay)
            growth_net = EMA(model=growth_net, decay=args.ema_decay)

        flow_nets.append(flow_net)
        geopath_nets.append(geopath_net)
        growth_nets.append(growth_net)

    return flow_nets, geopath_nets, growth_nets


def _make_trainer(args, callbacks, run_name, *, geopath_phase=False):
    """Build a Lightning ``Trainer`` with a fresh ``WandbLogger`` for one stage.

    All stages share ``max_epochs``, accelerator, root directory and gradient
    clipping (1.0 only when ``args.data_type == "image"``). The geopath stage
    always sets ``num_sanity_val_steps=0``; the other stages also pass
    ``check_val_every_n_epoch`` and set ``num_sanity_val_steps`` to 0 only for
    image data (``None`` otherwise).
    """
    is_image = args.data_type == "image"
    trainer_kwargs = dict(
        max_epochs=args.epochs,
        callbacks=callbacks,
        accelerator=args.accelerator,
        logger=WandbLogger(version=run_name),
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if is_image else None),
    )
    if geopath_phase:
        trainer_kwargs["num_sanity_val_steps"] = 0
    else:
        trainer_kwargs["check_val_every_n_epoch"] = args.check_val_every_n_epoch
        trainer_kwargs["num_sanity_val_steps"] = 0 if is_image else None
    return Trainer(**trainer_kwargs)


def _load_best_checkpoint(train_cls, callbacks, current_module, label, module_kwargs):
    """Reload the best checkpoint tracked by ``callbacks[0]``, if there is one.

    Returns ``current_module`` unchanged when ``best_model_path`` is empty;
    otherwise returns ``train_cls.load_from_checkpoint(path, **module_kwargs)``.
    """
    best_path = callbacks[0].best_model_path
    if not best_path:
        return current_module
    print(f"Loading best {label} model from: {best_path}")
    return train_cls.load_from_checkpoint(best_path, **module_kwargs)


def _as_module_list(nets):
    """Convert a tuple of networks to ``nn.ModuleList``; pass anything else through."""
    return nn.ModuleList(nets) if isinstance(nets, tuple) else nets


def _build_test_model(
    args,
    flow_matcher,
    flow_nets,
    growth_nets,
    ot_sampler,
    skipped_time_points,
    data_manifold_metric,
    joint,
    fallback,
):
    """Pick the dataset-specific test module by substring of ``args.data_name``.

    Keywords are checked in the order lidar, mouse, clonidine, trametinib,
    veres (case-insensitive); the first match wins. Lidar/mouse/veres test
    modules take the growth networks, the metric and the ``joint`` flag;
    clonidine/trametinib test modules take the flow matcher instead.
    ``fallback`` is returned when no keyword matches.
    """
    # (keyword, test class, whether the class takes growth nets / metric / joint)
    test_classes = (
        ("lidar", FlowNetTestLidar, True),
        ("mouse", FlowNetTestMouse, True),
        ("clonidine", FlowNetTestClonidine, False),
        ("trametinib", FlowNetTestTrametinib, False),
        ("veres", FlowNetTestVeres, True),
    )
    data_name = args.data_name.lower()
    for keyword, test_cls, uses_growth in test_classes:
        if keyword not in data_name:
            continue
        if uses_growth:
            return test_cls(
                flow_nets=flow_nets,
                growth_nets=growth_nets,
                ot_sampler=ot_sampler,
                skipped_time_points=skipped_time_points,
                args=args,
                data_manifold_metric=data_manifold_metric,
                joint=joint,
            )
        return test_cls(
            flow_matcher=flow_matcher,
            flow_nets=flow_nets,
            ot_sampler=ot_sampler,
            skipped_time_points=skipped_time_points,
            args=args,
        )
    return fallback


def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None:
    """Run the full four-stage pipeline for one seed and one held-out time point.

    Stages, each with its own ``Trainer`` and ``WandbLogger``:
      1. geopath interpolant (skipped when ``args.load_geopath_model_ckpt`` is set),
      2. flow networks,
      3. growth networks (``joint=False``), followed by a test run,
      4. joint training (``joint=True``), followed by a test run.

    Args:
        args: Merged run configuration. It is shallow-copied, so the caller's
            namespace (in particular ``run_name``) is not modified.
        seed: Random seed passed to ``set_seed``.
        t_exclude: Time point to hold out; a falsy value (``None`` or ``0``)
            means no time point is skipped.

    Raises:
        ValueError: If ``args.data_name`` is not a known dataset name.
    """
    set_seed(seed)

    # Work on a private copy: the caller reuses one namespace across seeds and
    # held-out time points, and rewriting run_name in place would stack another
    # datetime prefix onto it on every call.
    args = copy.copy(args)

    branches = args.branches
    skipped_time_points = [t_exclude] if t_exclude else []

    print("config path:")
    print(args.config_path)
    print("whiten")
    print(args.whiten)

    current_datetime = time.strftime("%m_%d_%H%M", time.localtime())
    run_name = f"{current_datetime}_{args.run_name}"
    args.run_name = run_name

    datamodule = _build_datamodule(args)
    if args.data_name == "veres":
        # NOTE(review): the source's indentation was lost; this override is read
        # as belonging to the veres branch only - confirm against the original.
        branches = datamodule.num_branches
        print("number of branches:", branches)

    flow_nets, geopath_nets, growth_nets = _build_networks(args, branches)

    # Both a real None and the string "None" disable optimal-transport sampling.
    ot_method = args.optimal_transport_method
    ot_sampler = (
        None if ot_method in (None, "None") else OTPlanSampler(method=ot_method)
    )

    wandb.init(
        project="branchsbm",
        name=run_name,
        config=vars(args),
        dir=args.working_dir,
    )

    flow_matcher_base = BranchSBM(
        geopath_nets=geopath_nets,
        sigma=args.sigma,
        alpha=int(args.branchsbm),
    )

    # ---- Stage 1: geopath interpolant ------------------------------------
    geopath_callbacks = create_callbacks(
        args, phase="geopath", data_type=args.data_type, run_id=wandb.run.id
    )
    data_manifold_metric = DataManifoldMetric(
        args=args,
        skipped_time_points=skipped_time_points,
        datamodule=datamodule,
    )
    geopath_model = BranchInterpolantTrain(
        flow_matcher=flow_matcher_base,
        skipped_time_points=skipped_time_points,
        ot_sampler=ot_sampler,
        args=args,
        data_manifold_metric=data_manifold_metric,
    )
    trainer = _make_trainer(args, geopath_callbacks, run_name, geopath_phase=True)

    if args.load_geopath_model_ckpt:
        best_model_path = args.load_geopath_model_ckpt
    else:
        trainer.fit(geopath_model, datamodule=datamodule)
        best_model_path = geopath_callbacks[0].best_model_path

    geopath_model = BranchInterpolantTrain.load_from_checkpoint(best_model_path)
    flow_matcher_base.geopath_nets = geopath_model.geopath_nets

    # ---- Stage 2: flow networks ------------------------------------------
    flow_callbacks = create_callbacks(
        args,
        phase="flow",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )
    is_lidar = args.data_type == "lidar"
    flow_train_cls = FlowNetTrainLidar if is_lidar else FlowNetTrainCell
    flow_train = flow_train_cls(
        flow_matcher=flow_matcher_base,
        flow_nets=flow_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
    )
    trainer = _make_trainer(args, flow_callbacks, run_name)
    trainer.fit(
        flow_train, datamodule=datamodule, ckpt_path=args.resume_flow_model_ckpt
    )
    if is_lidar:
        trainer.test(flow_train, datamodule=datamodule)

    flow_nets = flow_train.flow_nets

    # ---- Stage 3: growth networks (joint=False) --------------------------
    growth_callbacks = create_callbacks(
        args,
        phase="growth",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )
    growth_train_cls = GrowthNetTrainLidar if is_lidar else GrowthNetTrainCell
    # NOTE(review): with args.sequential the growth stage is trained with the
    # lidar/cell class but reloaded with SequentialGrowthNetTrain - kept as in
    # the original; confirm this mismatch is intended.
    sequential_or_default_cls = (
        SequentialGrowthNetTrain if args.sequential else growth_train_cls
    )
    growth_kwargs = dict(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=False,
    )
    growth_train = growth_train_cls(**growth_kwargs)
    trainer = _make_trainer(args, growth_callbacks, run_name)
    trainer.fit(growth_train, datamodule=datamodule, ckpt_path=None)

    growth_train = _load_best_checkpoint(
        sequential_or_default_cls,
        growth_callbacks,
        growth_train,
        "growth",
        growth_kwargs,
    )
    flow_nets = _as_module_list(growth_train.flow_nets)
    growth_nets = _as_module_list(growth_nets)

    test_model = _build_test_model(
        args,
        flow_matcher_base,
        flow_nets,
        growth_nets,
        ot_sampler,
        skipped_time_points,
        data_manifold_metric,
        joint=False,
        fallback=growth_train,
    )
    trainer.test(test_model, datamodule=datamodule)

    # The growth nets held by the (possibly reloaded) module feed the joint stage.
    growth_nets = growth_train.growth_nets

    # ---- Stage 4: joint training (joint=True) ----------------------------
    joint_callbacks = create_callbacks(
        args,
        phase="joint",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )
    joint_kwargs = dict(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=True,
    )
    joint_train = sequential_or_default_cls(**joint_kwargs)
    trainer = _make_trainer(args, joint_callbacks, run_name)
    trainer.fit(joint_train, datamodule=datamodule, ckpt_path=None)

    joint_train = _load_best_checkpoint(
        sequential_or_default_cls,
        joint_callbacks,
        joint_train,
        "joint",
        joint_kwargs,
    )
    flow_nets = _as_module_list(joint_train.flow_nets)
    growth_nets = _as_module_list(joint_train.growth_nets)

    # joint_train is only the fallback; the original's duplicated
    # `test_model = joint_train` line could discard the dataset-specific test
    # model, unlike the growth stage.
    test_model = _build_test_model(
        args,
        flow_matcher_base,
        flow_nets,
        growth_nets,
        ot_sampler,
        skipped_time_points,
        data_manifold_metric,
        joint=True,
        fallback=joint_train,
    )
    trainer.test(test_model, datamodule=datamodule)

    wandb.finish()
| | |
if __name__ == "__main__":
    # Parse the CLI, then overlay the optional config file on a deep copy so
    # the raw parsed arguments stay untouched.
    args = parse_args()
    updated_args = copy.deepcopy(args)
    if args.config_path:
        updated_args = merge_config(updated_args, load_config(args.config_path))

    updated_args.group_name = generate_group_string()
    updated_args.data_path = dataset_name2datapath(
        updated_args.data_name, updated_args.working_dir
    )

    # One pipeline run per seed and, when configured, per held-out time point.
    for seed in updated_args.seeds:
        if not updated_args.t_exclude:
            # Nothing is held out: use the first gamma and skip no time point.
            updated_args.seed_current = seed
            updated_args.gamma_current = updated_args.gammas[0]
            main(updated_args, seed=seed, t_exclude=None)
            continue

        # gammas is indexed in step with t_exclude.
        for idx, held_out_t in enumerate(updated_args.t_exclude):
            updated_args.t_exclude_current = held_out_t
            updated_args.seed_current = seed
            updated_args.gamma_current = updated_args.gammas[idx]
            main(updated_args, seed=seed, t_exclude=held_out_t)
| |
|