Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from pathlib import Path | |
| from absl import flags | |
| from absl import app | |
| from ml_collections import config_flags | |
| from train import train | |
# Command-line interface.  Every optional flag defaults to None so that main()
# can tell "not given on the command line" apart from an explicit value and
# leave the corresponding config entry untouched.
FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config",
    None,
    "Training configuration.",
    lock_config=False,
)
flags.mark_flags_as_required(["config"])

# ---- Paths ----
flags.DEFINE_string("workdir", None, "Work unit directory.")
flags.DEFINE_string(
    "workdir_base",
    None,
    "Base directory for workdir. If not provided, uses default path.",
)
flags.DEFINE_string(
    "vae_pretrained_path", None, "Path to pretrained VAE checkpoint."
)
flags.DEFINE_string(
    "model_pretrained_path", None, "Path to pretrained model checkpoint."
)
flags.DEFINE_string("fid_stat_path", None, "Path to FID statistics file.")
flags.DEFINE_string(
    "inception_ckpt_path", None, "Path to Inception checkpoint."
)
flags.DEFINE_string("sample_path", None, "Path to save samples.")
flags.DEFINE_string(
    "train_tar_pattern", None, "Training tar pattern for WebDataset."
)
flags.DEFINE_string(
    "test_tar_pattern", None, "Test tar pattern for WebDataset."
)
flags.DEFINE_string(
    "vis_image_root", None, "Path to visualization images root."
)
flags.DEFINE_string(
    "resume_ckpt_root",
    None,
    "Path to checkpoint root directory for resuming. If not provided, uses workdir/ckpts.",
)

# ---- WandB ----
flags.DEFINE_string(
    "wandb_project",
    None,
    "WandB project name. If not provided, uses config.wandb_project or default naming.",
)
flags.DEFINE_enum(
    "wandb_mode",
    None,
    ["online", "offline", "disabled"],
    "WandB mode: online (sync to cloud), offline (local only), or disabled.",
)

# ---- Training ----
flags.DEFINE_integer("n_steps", None, "Total training iterations.")
flags.DEFINE_integer(
    "batch_size", None, "Overall batch size across ALL gpus."
)
flags.DEFINE_integer(
    "log_interval", None, "Iteration interval for logging."
)
flags.DEFINE_integer(
    "eval_interval", None, "Iteration interval for visual testing."
)
flags.DEFINE_integer(
    "save_interval", None, "Iteration interval for saving checkpoints."
)
flags.DEFINE_integer(
    "n_samples_eval", None, "Number of samples for evaluation."
)

# ---- Dataset ----
flags.DEFINE_string("dataset_name", None, "Dataset name.")
flags.DEFINE_string("task", None, "Task name.")
flags.DEFINE_integer("resolution", None, "Dataset resolution.")
flags.DEFINE_integer(
    "shuffle_buffer", None, "Shuffle buffer size for WebDataset."
)
flags.DEFINE_boolean("resampled", None, "Whether to resample WebDataset.")
flags.DEFINE_boolean(
    "split_data_by_node", None, "Whether to split data by node."
)
flags.DEFINE_integer(
    "estimated_samples_per_shard", None, "Estimated samples per shard."
)
flags.DEFINE_string(
    "sampling_weights",
    None,
    "Sampling weights for multiple tar patterns (format: '0.7,0.3').",
)

# ---- Sampling ----
flags.DEFINE_integer(
    "sample_steps", None, "Sample steps during inference/testing."
)
flags.DEFINE_integer("n_samples", None, "Number of samples for testing.")
flags.DEFINE_integer("mini_batch_size", None, "Batch size for testing.")
flags.DEFINE_integer("scale", None, "CFG scale.")

# ---- Optimizer ----
flags.DEFINE_string("optimizer_name", None, "Optimizer name.")
flags.DEFINE_float("lr", None, "Learning rate.")
flags.DEFINE_float("weight_decay", None, "Weight decay.")
flags.DEFINE_string(
    "betas", None, "Betas for optimizer (format: '0.9,0.9')."
)
flags.DEFINE_enum(
    "adamw_impl",
    None,
    ["torch", "bitsandbytes", "AdamW", "AdamW8bit"],
    "Select AdamW backend.",
)

# ---- DataLoader ----
flags.DEFINE_integer(
    "num_workers", None, "Number of workers for DataLoader."
)

# ---- Model ----
flags.DEFINE_boolean(
    "use_cross_attention",
    None,
    "Whether to use cross attention in the first stage config.",
)
def get_config_name():
    """Return the stem of the config file named on the command line.

    Scans ``sys.argv`` for the ``config`` flag and returns the file name of
    its value without directory or final suffix (``configs/foo.py`` ->
    ``"foo"``).  Both ``--config=PATH`` and ``--config PATH`` are recognised,
    with either one or two leading dashes.

    Returns:
        The config file stem, or ``None`` when no ``config`` flag (with a
        value) is present in ``sys.argv``.
    """
    args = sys.argv[1:]
    for idx, arg in enumerate(args):
        # partition() splits on the FIRST '=', so a path that itself
        # contains '=' is kept intact.
        name, sep, value = arg.partition('=')
        if name not in ('--config', '-config'):
            continue
        if not sep:
            # Space-separated form: the value is the next argument.
            if idx + 1 >= len(args):
                return None
            value = args[idx + 1]
        return Path(value).stem
    return None
def get_hparams():
    """Summarise ``--config.*`` command-line overrides as one string.

    Each ``--config.a.b.name=value`` argument in ``sys.argv`` contributes a
    ``name=value`` item (only the last dotted component of the key is kept).
    When ``name`` ends with ``path`` the value is reduced to its file stem.
    Arguments starting with ``--config.dataset.path`` are left out.  Items
    are joined with ``-`` in command-line order.

    Arguments that are not ``--config.*`` overrides (for example
    ``--resampled`` or ``--workdir /x``) are ignored rather than validated.

    Returns:
        The joined summary, or ``'default'`` when there are no overrides.

    Raises:
        ValueError: if a ``--config.*`` override is not written as
            ``--config.key=value``.
    """
    items = []
    for arg in sys.argv[1:]:
        if not arg.startswith('--config.'):
            continue
        if arg.startswith('--config.dataset.path'):
            continue
        # Split on the first '=' only, so values may contain '=' themselves.
        key, sep, val = arg.partition('=')
        if not sep:
            raise ValueError(
                f"Config override {arg!r} must be given as --config.key=value"
            )
        key = key.split('.')[-1]
        if key.endswith('path'):
            val = Path(val).stem
        items.append(f'{key}={val}')
    return '-'.join(items) or 'default'
def main(argv):
    """Apply command-line overrides to the loaded config and start training.

    Reads ``FLAGS.config``, records the config name and hyper-parameter
    summary on it, derives the work/checkpoint/sample directories, copies
    every flag that was supplied onto the matching config entry, and finally
    calls ``train(config)``.

    Args:
        argv: positional command-line arguments; not used by this function.
    """
    cfg = FLAGS.config
    cfg.config_name = get_config_name()
    cfg.hparams = get_hparams()

    # Working directory: an explicit --workdir wins, otherwise it is built
    # from the base directory, the config name and the hparams summary.
    if FLAGS.workdir:
        cfg.workdir = FLAGS.workdir
    else:
        base_dir = FLAGS.workdir_base or '/path/to/workdir_base'
        cfg.workdir = os.path.join(base_dir, cfg.config_name, cfg.hparams)
    cfg.ckpt_root = os.path.join(cfg.workdir, 'ckpts')
    # Resume from the given checkpoint root if one was passed, else from
    # this run's own checkpoint directory.
    cfg.resume_ckpt_root = FLAGS.resume_ckpt_root or cfg.ckpt_root
    cfg.sample_dir = os.path.join(cfg.workdir, 'samples')

    def assign(dotted_key, value):
        # Walk attribute by attribute to the parent node, then set the leaf.
        *parents, leaf = dotted_key.split('.')
        node = cfg
        for attr in parents:
            node = getattr(node, attr)
        setattr(node, leaf, value)

    def float_list(text):
        # Parse a comma-separated string such as '0.7,0.3' into floats.
        return [float(piece.strip()) for piece in text.split(',')]

    def float_tuple(text):
        return tuple(float_list(text))

    # (flag name, config key): applied only when the flag value is truthy,
    # so an empty string counts as "not provided".
    truthy_overrides = (
        ('wandb_project', 'wandb_project'),
        ('wandb_mode', 'wandb_mode'),
        ('vae_pretrained_path', 'autoencoder.pretrained_path'),
        ('model_pretrained_path', 'pretrained_path'),
        ('fid_stat_path', 'fid_stat_path'),
        ('inception_ckpt_path', 'inception_ckpt_path'),
        ('sample_path', 'sample.path'),
        ('train_tar_pattern', 'dataset.train_tar_pattern'),
        ('test_tar_pattern', 'dataset.test_tar_pattern'),
        ('vis_image_root', 'dataset.vis_image_root'),
    )
    for flag_name, target in truthy_overrides:
        flag_value = getattr(FLAGS, flag_name)
        if flag_value:
            assign(target, flag_value)

    # (flag name, config key, converter): applied whenever the flag is not
    # None, so falsy values such as 0 or False are still written.
    explicit_overrides = (
        # Training parameters
        ('n_steps', 'train.n_steps', None),
        ('batch_size', 'train.batch_size', None),
        ('log_interval', 'train.log_interval', None),
        ('eval_interval', 'train.eval_interval', None),
        ('save_interval', 'train.save_interval', None),
        ('n_samples_eval', 'train.n_samples_eval', None),
        # Dataset parameters
        ('dataset_name', 'dataset.name', None),
        ('task', 'dataset.task', None),
        ('resolution', 'dataset.resolution', None),
        ('shuffle_buffer', 'dataset.shuffle_buffer', None),
        ('resampled', 'dataset.resampled', None),
        ('split_data_by_node', 'dataset.split_data_by_node', None),
        ('estimated_samples_per_shard',
         'dataset.estimated_samples_per_shard', None),
        ('sampling_weights', 'dataset.sampling_weights', float_list),
        # Sample parameters
        ('sample_steps', 'sample.sample_steps', None),
        ('n_samples', 'sample.n_samples', None),
        ('mini_batch_size', 'sample.mini_batch_size', None),
        ('scale', 'sample.scale', None),
        # Optimizer parameters
        ('optimizer_name', 'optimizer.name', None),
        ('lr', 'optimizer.lr', None),
        ('weight_decay', 'optimizer.weight_decay', None),
        ('betas', 'optimizer.betas', float_tuple),
        ('adamw_impl', 'optimizer.adamw_impl', None),
        # DataLoader parameters
        ('num_workers', 'num_workers', None),
    )
    for flag_name, target, convert in explicit_overrides:
        flag_value = getattr(FLAGS, flag_name)
        if flag_value is None:
            continue
        if convert is not None:
            flag_value = convert(flag_value)
        assign(target, flag_value)

    # Model parameters: the flag only affects the first stage config, and
    # only when the model arguments actually carry a non-empty stage list.
    if FLAGS.use_cross_attention is not None:
        model_args = cfg.nnet.model_args
        if hasattr(model_args, 'stage_configs') and len(model_args.stage_configs) > 0:
            model_args.stage_configs[0].use_cross_attention = FLAGS.use_cross_attention

    train(cfg)
# Script entry point: hand main() to absl's application runner.
if __name__ == "__main__":
    app.run(main)