"""Entry point for multi-stage BranchSBM training.

For every (seed, excluded time point) combination the script runs four
stages in order: geodesic-interpolant training, flow matching, growth-network
training and joint training. All stages log to a single Weights & Biases run.
"""
import argparse
import copy
import os
import sys
from typing import Optional

# Kept ahead of the remaining imports, as in the original script, so that
# modules located under ./BranchSBM are on the import search path.
sys.path.append("./BranchSBM")

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
import wandb
import hydra
from omegaconf import DictConfig, OmegaConf
from torchcfm.optimal_transport import OTPlanSampler
import torch.nn as nn

from branchsbm.branchsbm import BranchSBM
from branchsbm.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar
from branchsbm.branch_interpolant_train import BranchInterpolantTrain
from branchsbm.branch_growth_net_train import (
    GrowthNetTrain,
    GrowthNetTrainCell,
    GrowthNetTrainLidar,
)
from dataloaders.trajectory_data import TemporalDataModule
from dataloaders.mouse_data import WeightedBranchedCellDataModule
from dataloaders.three_branch_data import ThreeBranchTahoeDataModule
from dataloaders.clonidine_v2_data import ClonidineV2DataModule
from dataloaders.clonidine_single_branch import ClonidineSingleBranchDataModule
from dataloaders.trametinib_single import TrametinibSingleBranchDataModule
from dataloaders.lidar_data import WeightedBranchedLidarDataModule
from dataloaders.lidar_data_single import LidarSingleDataModule
from networks.flow_mlp import VelocityNet
from networks.growth_mlp import GrowthNet
from networks.interpolant_mlp import GeoPathMLP
from utils import set_seed
from train.parsers import parse_args
from branchsbm.ema import EMA
from train.train_utils import (
    load_config,
    merge_config,
    generate_group_string,
    dataset_name2datapath,
    create_callbacks,
)
from state_costs.metric_factory import DataManifoldMetric


# Maps every supported ``args.data_name`` to the datamodule class built for it.
_DATAMODULE_BY_NAME = {
    "lidar": WeightedBranchedLidarDataModule,
    "lidarsingle": LidarSingleDataModule,
    "mouse": WeightedBranchedCellDataModule,
    "clonidine50D": ClonidineV2DataModule,
    "clonidine100D": ClonidineV2DataModule,
    "clonidine150D": ClonidineV2DataModule,
    "clonidine50Dsingle": ClonidineSingleBranchDataModule,
    "trametinib": ThreeBranchTahoeDataModule,
    "trametinibsingle": TrametinibSingleBranchDataModule,
}


def _build_datamodule(args: argparse.Namespace):
    """Instantiate the datamodule registered for ``args.data_name``.

    Raises:
        ValueError: if ``args.data_name`` is not a supported dataset name.
            Failing here avoids a much later ``UnboundLocalError`` that would
            only surface after a W&B run had already been opened.
    """
    try:
        datamodule_cls = _DATAMODULE_BY_NAME[args.data_name]
    except KeyError:
        supported = ", ".join(sorted(_DATAMODULE_BY_NAME))
        raise ValueError(
            f"Unsupported data_name {args.data_name!r}; expected one of: {supported}"
        ) from None
    return datamodule_cls(args=args)


def _build_branch_networks(args: argparse.Namespace, branches: int):
    """Create one flow, geopath and growth network per branch.

    Networks are constructed in the order flow -> geopath -> growth for each
    branch so that parameter initialisation consumes the random stream in a
    fixed order. The growth network of branch 0 is built with
    ``negative=True`` and every other branch with ``negative=False``. When
    ``args.ema_decay`` is not ``None`` each network is wrapped in ``EMA``.

    Returns:
        A ``(flow_nets, geopath_nets, growth_nets)`` tuple of ``nn.ModuleList``.
    """
    flow_nets = nn.ModuleList()
    geopath_nets = nn.ModuleList()
    growth_nets = nn.ModuleList()

    for branch_idx in range(branches):
        flow_net = VelocityNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_flow,
            activation=args.activation_flow,
            batch_norm=False,
        )
        geopath_net = GeoPathMLP(
            input_dim=args.dim,
            hidden_dims=args.hidden_dims_geopath,
            time_geopath=args.time_geopath,
            activation=args.activation_geopath,
            batch_norm=False,
        )
        growth_net = GrowthNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_growth,
            activation=args.activation_growth,
            batch_norm=False,
            negative=(branch_idx == 0),
        )

        if args.ema_decay is not None:
            flow_net = EMA(model=flow_net, decay=args.ema_decay)
            geopath_net = EMA(model=geopath_net, decay=args.ema_decay)
            growth_net = EMA(model=growth_net, decay=args.ema_decay)

        flow_nets.append(flow_net)
        geopath_nets.append(geopath_net)
        growth_nets.append(growth_net)

    return flow_nets, geopath_nets, growth_nets


def _build_trainer(
    args: argparse.Namespace, callbacks, *, geopath_stage: bool = False
) -> Trainer:
    """Create a ``Trainer`` with a fresh ``WandbLogger`` for one stage.

    The geopath stage always disables sanity validation and does not set
    ``check_val_every_n_epoch``. Every other stage forwards
    ``args.check_val_every_n_epoch`` and disables sanity validation only when
    ``args.data_type == "image"``. Gradient clipping at 1.0 is likewise only
    enabled for ``"image"`` data.
    """
    is_image = args.data_type == "image"
    if geopath_stage:
        stage_kwargs = {"num_sanity_val_steps": 0}
    else:
        stage_kwargs = {
            "check_val_every_n_epoch": args.check_val_every_n_epoch,
            "num_sanity_val_steps": 0 if is_image else None,
        }
    return Trainer(
        max_epochs=args.epochs,
        callbacks=callbacks,
        accelerator=args.accelerator,
        logger=WandbLogger(),
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if is_image else None),
        **stage_kwargs,
    )


def main(args: argparse.Namespace, seed: int, t_exclude: Optional[int]) -> None:
    """Run the full four-stage training pipeline for one seed.

    Args:
        args: Merged command-line / config namespace driving every stage.
        seed: Random seed passed to ``set_seed`` before anything is built.
        t_exclude: Time point to hold out, or ``None`` to keep all of them.

    Raises:
        ValueError: if ``args.data_name`` is not a supported dataset name.
    """
    set_seed(seed)
    branches = args.branches
    # NOTE(review): this is a truthiness test, so t_exclude == 0 behaves like
    # None (nothing is skipped). Preserved from the original; confirm intent.
    skipped_time_points = [t_exclude] if t_exclude else []

    ### DATAMODULES ###
    datamodule = _build_datamodule(args)

    ##### initialize branched flow and growth networks #####
    flow_nets, geopath_nets, growth_nets = _build_branch_networks(args, branches)

    # The literal string "None" (not the None object) disables OT sampling.
    if args.optimal_transport_method != "None":
        ot_sampler = OTPlanSampler(method=args.optimal_transport_method)
    else:
        ot_sampler = None

    wandb.init(
        project=f"branchsbm-{args.data_name}-{branches}-branches",
        group=args.group_name,
        config=vars(args),
        dir=args.working_dir,
    )

    flow_matcher_base = BranchSBM(
        geopath_nets=geopath_nets,
        sigma=args.sigma,
        alpha=int(args.branchsbm),
    )

    def stage_callbacks(phase: str):
        """Callbacks for the post-geopath stages (these also get the datamodule)."""
        return create_callbacks(
            args,
            phase=phase,
            data_type=args.data_type,
            run_id=wandb.run.id,
            datamodule=datamodule,
        )

    ##### STAGE 1: Training of Geodesic Interpolants Beginning #####
    geopath_callbacks = create_callbacks(
        args, phase="geopath", data_type=args.data_type, run_id=wandb.run.id
    )

    # define state cost
    data_manifold_metric = DataManifoldMetric(
        args=args,
        skipped_time_points=skipped_time_points,
        datamodule=datamodule,
    )
    geopath_model = BranchInterpolantTrain(
        flow_matcher=flow_matcher_base,
        skipped_time_points=skipped_time_points,
        ot_sampler=ot_sampler,
        args=args,
        data_manifold_metric=data_manifold_metric,
    )
    trainer = _build_trainer(args, geopath_callbacks, geopath_stage=True)

    if args.load_geopath_model_ckpt:
        # A pre-trained interpolant was supplied: skip fitting entirely.
        best_model_path = args.load_geopath_model_ckpt
    else:
        trainer.fit(geopath_model, datamodule=datamodule)
        best_model_path = geopath_callbacks[0].best_model_path

    # Either way, reload from the checkpoint and hand its networks to the
    # shared flow matcher used by the later stages.
    geopath_model = BranchInterpolantTrain.load_from_checkpoint(best_model_path)
    flow_matcher_base.geopath_nets = geopath_model.geopath_nets
    ##### STAGE 1: Training of Geodesic Interpolants End #####

    is_lidar = args.data_type == "lidar"

    ##### STAGE 2: Flow Matching Beginning #####
    flow_callbacks = stage_callbacks("flow")
    flow_train_cls = FlowNetTrainLidar if is_lidar else FlowNetTrainCell
    flow_train = flow_train_cls(
        flow_matcher=flow_matcher_base,
        flow_nets=flow_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
    )
    trainer = _build_trainer(args, flow_callbacks)
    trainer.fit(
        flow_train, datamodule=datamodule, ckpt_path=args.resume_flow_model_ckpt
    )
    # Only the lidar configuration is tested after flow matching.
    if is_lidar:
        trainer.test(flow_train, datamodule=datamodule)
    ##### STAGE 2: Flow Matching End #####

    ##### STAGE 3: Training Growth Networks Beginning ####
    flow_nets = flow_train.flow_nets
    growth_callbacks = stage_callbacks("growth")
    growth_train_cls = GrowthNetTrainLidar if is_lidar else GrowthNetTrainCell
    growth_train = growth_train_cls(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=False,
    )
    trainer = _build_trainer(args, growth_callbacks)
    trainer.fit(growth_train, datamodule=datamodule, ckpt_path=None)
    trainer.test(growth_train, datamodule=datamodule)
    ##### STAGE 3: Training Growth Networks End ####

    ##### STAGE 4: Joint Training Beginning ####
    growth_nets = growth_train.growth_nets
    joint_callbacks = stage_callbacks("joint")
    joint_train = growth_train_cls(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=True,
    )
    trainer = _build_trainer(args, joint_callbacks)
    trainer.fit(joint_train, datamodule=datamodule, ckpt_path=None)
    trainer.test(joint_train, datamodule=datamodule)
    ##### STAGE 4: Joint Training End ####

    wandb.finish()


if __name__ == "__main__":
    args = parse_args()
    # Work on a copy so the parsed CLI namespace itself is left untouched.
    updated_args = copy.deepcopy(args)
    if args.config_path:
        config = load_config(args.config_path)
        updated_args = merge_config(updated_args, config)

    updated_args.group_name = generate_group_string()
    updated_args.data_path = dataset_name2datapath(
        updated_args.data_name, updated_args.working_dir
    )

    for seed in updated_args.seeds:
        if updated_args.t_exclude:
            # One run per held-out time point; gammas[i] is paired with
            # t_exclude[i] by position.
            for i, t_exclude in enumerate(updated_args.t_exclude):
                updated_args.t_exclude_current = t_exclude
                updated_args.seed_current = seed
                updated_args.gamma_current = updated_args.gammas[i]
                main(updated_args, seed=seed, t_exclude=t_exclude)
        else:
            # Nothing held out: a single run using the first gamma.
            updated_args.seed_current = seed
            updated_args.gamma_current = updated_args.gammas[0]
            main(updated_args, seed=seed, t_exclude=None)