|
|
import os |
|
|
import sys |
|
|
# Restrict this process to a single GPU. This must run before `import torch`
# below, because CUDA reads the variable when it initialises. setdefault keeps
# a device selection the caller already exported instead of silently
# overriding it; "6" is only the fallback used when nothing was set.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6")
|
|
|
|
|
import torch |
|
|
import wandb |
|
|
|
|
|
from entangledcell_module_unseen import EntangledNetTrainCellUnseen |
|
|
from entangledcell_module_three import EntangledNetTrainCellThree |
|
|
|
|
|
|
|
|
from dataloaders.three_branch_data import ThreeBranchTahoeDataModule |
|
|
from dataloaders.clonidine_v2_data import ClonidineV2DataModule |
|
|
|
|
|
from geo_metrics.metric_factory import DataManifoldMetric |
|
|
from pytorch_lightning.loggers import WandbLogger |
|
|
from pytorch_lightning import Trainer |
|
|
|
|
|
from torchcfm.optimal_transport import OTPlanSampler |
|
|
from parser import parse_args |
|
|
from train_utils import load_config, merge_config |
|
|
from bias import BiasForceTransformer, BiasForceTransformerNoVel |
|
|
|
|
|
def _make_ot_sampler(method):
    """Return an ``OTPlanSampler`` for *method*, or ``None`` to disable OT pairing.

    Both the string ``"None"`` and an actual ``None`` value (e.g. if
    ``merge_config`` yields a real null) disable the sampler, so that
    ``OTPlanSampler(method=None)`` is never constructed.
    """
    if method is None or method == "None":
        return None
    return OTPlanSampler(method=method)


def main():
    """Build all training components, then fit and test the model.

    Steps:
      1. Obtain arguments from ``parse_args``. When ``args.config_path`` is
         set, merge the loaded config into them.
      2. Mark the run as training, create ``<save_dir>/positions``, and start
         a wandb run in project ``"entangled-cell"``.
      3. Seed torch, then build (in this order) the OT sampler, datamodule,
         data-manifold metric, bias network and Lightning module.
      4. Run ``trainer.fit`` followed by ``trainer.test``.

    ``data_name == "trametinib"`` selects the three-branch datamodule/module
    pair. Every other ``data_name`` falls back to the Clonidine-v2 datamodule
    with the "unseen" module.
    """
    args = parse_args()
    if args.config_path:
        config = load_config(args.config_path)
        args = merge_config(args, config)

    args.training = True

    # exist_ok avoids the race between an exists() check and makedirs().
    positions_dir = f"{args.save_dir}/positions"
    os.makedirs(positions_dir, exist_ok=True)

    wandb.init(project="entangled-cell", config=args, name=args.run_name)

    # Seed before any model/data construction. The construction order below is
    # kept fixed so the torch RNG is consumed in the same sequence every run.
    torch.manual_seed(args.seed)

    ot_sampler = _make_ot_sampler(args.optimal_transport_method)

    three_branch = args.data_name == "trametinib"
    if three_branch:
        datamodule = ThreeBranchTahoeDataModule(args=args)
    else:
        datamodule = ClonidineV2DataModule(args=args)

    data_manifold_metric = DataManifoldMetric(
        args=args,
        skipped_time_points=[],
        datamodule=datamodule,
    )

    if args.vel_conditioned:
        bias_net = BiasForceTransformer(args)
    else:
        print("Using no velocity conditioned model")
        bias_net = BiasForceTransformerNoVel(args)

    timepoint_data = datamodule.get_timepoint_data()

    # Both Lightning modules are constructed with identical keyword arguments;
    # only the class differs between the two dataset families.
    module_cls = (
        EntangledNetTrainCellThree if three_branch else EntangledNetTrainCellUnseen
    )
    entangled_train = module_cls(
        args=args,
        bias_net=bias_net,
        data_manifold_metric=data_manifold_metric,
        timepoint_data=timepoint_data,
        ot_sampler=ot_sampler,
        vel_conditioned=args.vel_conditioned,
    )

    # Created after wandb.init(). Lightning's WandbLogger is expected to reuse
    # that active run rather than start a second one.
    wandb_logger = WandbLogger()

    trainer = Trainer(
        max_epochs=args.num_rollouts,
        logger=wandb_logger,
        num_sanity_val_steps=0,
        default_root_dir=args.root_dir,
        gradient_clip_val=None,
        # Index into the devices exposed by CUDA_VISIBLE_DEVICES, not a
        # physical GPU id.
        devices=[0],
    )

    trainer.fit(entangled_train, datamodule=datamodule)
    trainer.test(entangled_train, datamodule=datamodule)
|
|
|
|
|
# Script entry point: run the full train/test pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()