EntangledSBM / entangled-cell /train_cell.py
Sophia Tang
Initial commit with LFS
7efee70
import os
import sys
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import torch
import wandb
from entangledcell_module_unseen import EntangledNetTrainCellUnseen
from entangledcell_module_three import EntangledNetTrainCellThree
# cell
from dataloaders.three_branch_data import ThreeBranchTahoeDataModule
from dataloaders.clonidine_v2_data import ClonidineV2DataModule
from geo_metrics.metric_factory import DataManifoldMetric
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from torchcfm.optimal_transport import OTPlanSampler
from parser import parse_args
from train_utils import load_config, merge_config
from bias import BiasForceTransformer, BiasForceTransformerNoVel
def main():
args = parse_args()
if args.config_path:
config = load_config(args.config_path)
args = merge_config(args, config)
args.training = True
args.save_dir = args.save_dir
# Create positions directory for saving trajectory samples
positions_dir = f"{args.save_dir}/positions"
if not os.path.exists(positions_dir):
os.makedirs(positions_dir)
wandb.init(project="entangled-cell",
config=args,
name=args.run_name)
torch.manual_seed(args.seed)
ot_sampler = (
OTPlanSampler(method=args.optimal_transport_method)
if args.optimal_transport_method != "None"
else None
)
# get data
if args.data_name == "trametinib":
datamodule = ThreeBranchTahoeDataModule(args=args)
else:
datamodule = ClonidineV2DataModule(args=args)
# data manifold metrics
data_manifold_metric = DataManifoldMetric(
args=args,
skipped_time_points=[],
datamodule=datamodule,
)
if args.vel_conditioned:
bias_net = BiasForceTransformer(args)
else:
print("Using no velocity conditioned model")
bias_net = BiasForceTransformerNoVel(args)
timepoint_data = datamodule.get_timepoint_data()
if args.data_name == "trametinib":
entangled_train = EntangledNetTrainCellThree(args=args,
bias_net=bias_net,
data_manifold_metric=data_manifold_metric,
timepoint_data=timepoint_data,
ot_sampler=ot_sampler,
vel_conditioned=args.vel_conditioned)
else:
entangled_train = EntangledNetTrainCellUnseen(args=args,
bias_net=bias_net,
data_manifold_metric=data_manifold_metric,
timepoint_data=timepoint_data,
ot_sampler=ot_sampler,
vel_conditioned=args.vel_conditioned)
wandb_logger = WandbLogger()
trainer = Trainer(
max_epochs=args.num_rollouts,
logger=wandb_logger,
num_sanity_val_steps=0,
default_root_dir=args.root_dir,
gradient_clip_val=None,
devices=[0],
)
trainer.fit(
entangled_train, datamodule=datamodule
)
trainer.test(entangled_train, datamodule=datamodule)
if __name__ == "__main__":
main()