BranchSBM / train /parsers.py
sophiat44
model upload
5a87d8d
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Train BranchSBM")
parser.add_argument(
"--config_path", type=str,
default='./configs/experiment/lidar.yaml',
help="Path to config file"
)
####### ITERATES IN THE CODE #######
parser.add_argument(
"--seeds",
nargs="+",
type=int,
default=[42, 43, 44, 45, 46],
help="Random seeds to iterate over",
)
parser.add_argument(
"--t_exclude",
nargs="+",
type=int,
default=[1, 2],
help="Time points to exclude (iterating over)",
)
####################################
parser.add_argument(
"--working_dir",
type=str,
default="./",
help="Working directory",
)
parser.add_argument(
"--resume_flow_model_ckpt",
type=str,
default=None,
help="Path to the flow model to resume training",
)
parser.add_argument(
"--resume_growth_model_ckpt",
type=str,
default=None,
help="Path to the flow model to resume training",
)
parser.add_argument(
"--load_geopath_model_ckpt",
type=str,
default=None,
help="Path to the geopath model to resume training",
)
parser.add_argument(
"--branches",
type=int,
default=2,
help="Number of branches",
)
parser.add_argument(
"--metric_clusters",
type=int,
default=3,
help="Number of metric clusters",
)
######### DATASETS #################
parser = datasets_parser(parser)
####################################
######### IMAGE DATASETS ###########
parser = image_datasets_parser(parser)
####################################
######### METRICS ##################
parser = metric_parser(parser)
####################################
######### General Training #########
parser = general_training_parser(parser)
####################################
######### Training GeoPath Network ####
parser = geopath_network_parser(parser)
####################################
######### Training Flow Network ####
parser = flow_network_parser(parser)
####################################
parser = growth_network_parser(parser)
return parser.parse_args()
def datasets_parser(parser):
parser.add_argument("--dim", type=int, default=3, help="Dimension of data")
parser.add_argument(
"--data_type",
type=str,
default="lidar",
help="Type of data, now wither scrna or one of toys",
)
parser.add_argument(
"--data_path",
type=str,
default="./data/rainier2-thin.las",
help="lidar data path",
)
parser.add_argument(
"--data_name",
type=str,
default="lidar",
help="Path to the dataset",
)
parser.add_argument(
"--whiten",
action=argparse.BooleanOptionalAction,
default=True,
help="Whiten the data",
)
return parser
def image_datasets_parser(parser):
parser.add_argument(
"--image_size",
type=int,
default=128,
help="Size of the image",
)
parser.add_argument(
"--x0_label",
type=str,
default="dog",
help="Label for x0",
)
parser.add_argument(
"--x1_label",
type=str,
default="cat",
help="Label for x1",
)
return parser
def metric_parser(parser):
parser.add_argument(
"--branchsbm",
action=argparse.BooleanOptionalAction,
default=True,
help="If branched SBM",
)
parser.add_argument(
"--n_centers",
type=int,
default=100,
help="Number of centers for RBF network",
)
parser.add_argument(
"--kappa",
type=float,
default=1.0,
help="Kappa parameter for RBF network",
)
parser.add_argument(
"--rho",
type=float,
default=0.001,
help="Rho parameter in Riemanian Velocity Calculation",
)
parser.add_argument(
"--velocity_metric",
type=str,
default="rbf",
help="Metric for velocity calculation",
)
parser.add_argument(
"--gammas",
nargs="+",
type=float,
default=[0.2, 0.2],
help="Gamma parameter in Riemanian Velocity Calculation",
)
parser.add_argument(
"--metric_epochs",
type=int,
default=50,
help="Number of epochs for metric learning",
)
parser.add_argument(
"--metric_patience",
type=int,
default=5,
help="Patience for metric learning",
)
parser.add_argument(
"--metric_lr",
type=float,
default=1e-2,
help="Learning rate for metric learning",
)
parser.add_argument(
"--alpha_metric",
type=float,
default=1.0,
help="Alpha parameter for metric learning",
)
return parser
def general_training_parser(parser):
parser.add_argument(
"--batch_size", type=int, default=128, help="Batch size for training"
)
parser.add_argument(
"--optimal_transport_method",
type=str,
default="exact",
help="Use optimal transport in CFM training",
)
parser.add_argument(
"--ema_decay",
type=float,
default=None,
help="Decay for EMA",
)
parser.add_argument(
"--split_ratios",
nargs=2,
type=float,
default=[0.9, 0.1],
help="Split ratios for training/validation data in CFM training",
)
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
parser.add_argument(
"--accelerator", type=str, default="cpu", help="Training accelerator"
)
parser.add_argument(
"--sim_num_steps",
type=int,
default=1000,
help="Number of steps in simulation",
)
return parser
def geopath_network_parser(parser):
parser.add_argument(
"--manifold",
action=argparse.BooleanOptionalAction,
default=True,
help="If use data manifold metric",
)
parser.add_argument(
"--patience_geopath",
type=int,
default=50,
help="Patience for training geopath model",
)
parser.add_argument(
"--hidden_dims_geopath",
nargs="+",
type=int,
default=[64, 64, 64],
help="Dimensions of hidden layers for GeoPath model training",
)
parser.add_argument(
"--time_geopath",
action=argparse.BooleanOptionalAction,
default=False,
help="Use time in GeoPath model",
)
parser.add_argument(
"--activation_geopath",
type=str,
default="selu",
help="Activation function for GeoPath",
)
parser.add_argument(
"--geopath_optimizer",
type=str,
default="adam",
help="Optimizer for GeoPath training",
)
parser.add_argument(
"--geopath_lr",
type=float,
default=1e-4,
help="Learning rate for GeoPath training",
)
parser.add_argument(
"--geopath_weight_decay",
type=float,
default=1e-5,
help="Weight decay for GeoPath training",
)
return parser
def flow_network_parser(parser):
parser.add_argument(
"--sigma", type=float, default=0.1, help="Sigma parameter for CFM (variance)"
)
parser.add_argument(
"--patience",
type=int,
default=5,
help="Patience for early stopping in CFM training",
)
parser.add_argument(
"--hidden_dims_flow",
nargs="+",
type=int,
default=[64, 64, 64],
help="Dimensions of hidden layers for CFM training",
)
parser.add_argument(
"--check_val_every_n_epoch",
type=int,
default=10,
help="Check validation every N epochs during CFM training",
)
parser.add_argument(
"--activation_flow",
type=str,
default="selu",
help="Activation function for CFM",
)
parser.add_argument(
"--flow_optimizer",
type=str,
default="adamw",
help="Optimizer for GeoPath training",
)
parser.add_argument(
"--flow_lr",
type=float,
default=1e-3,
help="Learning rate for GeoPath training",
)
parser.add_argument(
"--flow_weight_decay",
type=float,
default=1e-5,
help="Weight decay for GeoPath training",
)
return parser
def growth_network_parser(parser):
parser.add_argument(
"--patience_growth",
type=int,
default=5,
help="Patience for early stopping in CFM training",
)
parser.add_argument(
"--time_growth",
action=argparse.BooleanOptionalAction,
default=False,
help="Use time in GeoPath model",
)
parser.add_argument(
"--hidden_dims_growth",
nargs="+",
type=int,
default=[64, 64, 64],
help="Dimensions of hidden layers for growth net training",
)
parser.add_argument(
"--activation_growth",
type=str,
default="tanh",
help="Activation function for CFM",
)
parser.add_argument(
"--growth_optimizer",
type=str,
default="adamw",
help="Optimizer for GeoPath training",
)
parser.add_argument(
"--growth_lr",
type=float,
default=1e-3,
help="Learning rate for GeoPath training",
)
parser.add_argument(
"--growth_weight_decay",
type=float,
default=1e-5,
help="Weight decay for GeoPath training",
)
parser.add_argument(
"--lambda_energy",
type=float,
default=1.0,
help="Weight for energy loss",
)
parser.add_argument(
"--lambda_mass",
type=float,
default=100.0,
help="Weight for mass loss",
)
parser.add_argument(
"--lambda_match",
type=float,
default=1000.0,
help="Weight for matching loss",
)
parser.add_argument(
"--lambda_recons",
type=float,
default=1.0,
help="Weight for reconstruction loss",
)
return parser