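"""Argument parsing and distributed initialization for SAT-based sampling and training."""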
import argparse
import os
import torch
import json
import warnings
import omegaconf
from omegaconf import OmegaConf
from sat.helpers import print_rank0
from sat import mpu
from sat.arguments import set_random_seed
from sat.arguments import add_training_args, add_evaluation_args, add_data_args
import torch.distributed
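

# Helper (added): argparse's bare `type=bool` treats any non-empty string (including
# "False") as True, so boolean-valued options below use this explicit parser instead.
def str2bool(v):
    """Parse common boolean spellings from the command line."""
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")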


def add_model_config_args(parser):
    """Model arguments."""
    group = parser.add_argument_group("model", "model configuration")
    group.add_argument("--base", type=str, nargs="*", help="config for input and saving")
    group.add_argument(
        "--model-parallel-size", type=int, default=1, help="size of model parallelism; only change if you are an expert."
    )
    group.add_argument("--force-pretrain", action="store_true")
    group.add_argument("--device", type=int, default=-1)
    group.add_argument("--debug", action="store_true")
    group.add_argument("--log-image", type=str2bool, default=True)
    return parser


def add_sampling_config_args(parser):
    """Sampling configurations."""
    group = parser.add_argument_group("sampling", "Sampling Configurations")
    group.add_argument("--output-dir", type=str, default="samples")
    group.add_argument("--input-dir", type=str, default=None)
    group.add_argument("--input-type", type=str, default="cli")
    group.add_argument("--input-file", type=str, default="input.txt")
    group.add_argument("--final-size", type=int, default=2048)
    group.add_argument("--sdedit", action="store_true")
    group.add_argument("--grid-num-rows", type=int, default=1)
    group.add_argument("--force-inference", action="store_true")
    group.add_argument("--lcm_steps", type=int, default=None)
    group.add_argument("--sampling-num-frames", type=int, default=32)
    group.add_argument("--sampling-fps", type=int, default=8)
    group.add_argument("--only-save-latents", type=str2bool, default=False)
    group.add_argument("--only-log-video-latents", type=str2bool, default=False)
    group.add_argument("--latent-channels", type=int, default=32)
    group.add_argument("--image2video", action="store_true")
    return parser


def get_args(args_list=None, parser=None):
    """Parse all the args."""
    if parser is None:
        parser = argparse.ArgumentParser(description="sat")
    else:
        assert isinstance(parser, argparse.ArgumentParser)
    parser = add_model_config_args(parser)
    parser = add_sampling_config_args(parser)
    parser = add_training_args(parser)
    parser = add_evaluation_args(parser)
    parser = add_data_args(parser)

    import deepspeed

    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args(args_list)
    args = process_config_to_args(args)

    if not args.train_data:
        print_rank0("No training data specified", level="WARNING")

    assert (args.train_iters is None) or (args.epochs is None), "only one of train_iters and epochs should be set."
    if args.train_iters is None and args.epochs is None:
        args.train_iters = 10000  # default 10k iters
        print_rank0("No train_iters (recommended) or epochs specified, use default 10k iters.", level="WARNING")

    args.cuda = torch.cuda.is_available()

    args.rank = int(os.getenv("RANK", "0"))
    args.world_size = int(os.getenv("WORLD_SIZE", "1"))
    if args.local_rank is None:
        args.local_rank = int(os.getenv("LOCAL_RANK", "0"))  # torchrun
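
    # Resolve the target device: prefer LOCAL_RANK (set by torchrun); otherwise fall
    # back to rank modulo the number of visible GPUs, or to CPU when no GPU is available.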
    if args.device == -1:
        if torch.cuda.device_count() == 0:
            args.device = "cpu"
        elif args.local_rank is not None:
            args.device = args.local_rank
        else:
            args.device = args.rank % torch.cuda.device_count()

    if args.local_rank != args.device and args.mode != "inference":
        raise ValueError(
            "LOCAL_RANK (default 0) and args.device are inconsistent. "
            "This is only allowed in inference mode. "
            "Please use CUDA_VISIBLE_DEVICES=x for single-GPU training."
        )

    if args.rank == 0:
        print_rank0("using world size: {}".format(args.world_size))

    if args.train_data_weights is not None:
        assert len(args.train_data_weights) == len(args.train_data)
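
    # Training always runs through DeepSpeed; if no config was passed on the command
    # line, load the bundled default for the requested ZeRO stage and remember that we
    # are free to override it from args below.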
    # Guard: override_deepspeed_config is read below even when mode == "inference".
    override_deepspeed_config = False
    if args.mode != "inference":  # training with deepspeed
        args.deepspeed = True
        if args.deepspeed_config is None:  # not specified
            deepspeed_config_path = os.path.join(
                os.path.dirname(__file__), "training", f"deepspeed_zero{args.zero_stage}.json"
            )
            with open(deepspeed_config_path) as file:
                args.deepspeed_config = json.load(file)
            override_deepspeed_config = True

    assert not (args.fp16 and args.bf16), "cannot specify both fp16 and bf16."

    if args.zero_stage > 0 and not args.fp16 and not args.bf16:
        print_rank0("Automatically set fp16=True to use ZeRO.")
        args.fp16 = True
        args.bf16 = False

    if args.deepspeed:
        if args.checkpoint_activations:
            args.deepspeed_activation_checkpointing = True
        else:
            args.deepspeed_activation_checkpointing = False
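
        # Keep args and the DeepSpeed config consistent: when the config came from the
        # bundled default, push values from args into it; when the user supplied their
        # own config, let it override the corresponding args instead.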
        if args.deepspeed_config is not None:
            deepspeed_config = args.deepspeed_config
            if override_deepspeed_config:  # did not specify deepspeed_config, use args
                if args.fp16:
                    deepspeed_config["fp16"]["enabled"] = True
                elif args.bf16:
                    deepspeed_config["bf16"]["enabled"] = True
                    deepspeed_config["fp16"]["enabled"] = False
                else:
                    deepspeed_config["fp16"]["enabled"] = False
                deepspeed_config["train_micro_batch_size_per_gpu"] = args.batch_size
                deepspeed_config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
                optimizer_params_config = deepspeed_config["optimizer"]["params"]
                optimizer_params_config["lr"] = args.lr
                optimizer_params_config["weight_decay"] = args.weight_decay
            else:  # override args with values in deepspeed_config
                if args.rank == 0:
                    print_rank0("Will override arguments with manually specified deepspeed_config!")
                if "fp16" in deepspeed_config and deepspeed_config["fp16"]["enabled"]:
                    args.fp16 = True
                else:
                    args.fp16 = False
                if "bf16" in deepspeed_config and deepspeed_config["bf16"]["enabled"]:
                    args.bf16 = True
                else:
                    args.bf16 = False
                if "train_micro_batch_size_per_gpu" in deepspeed_config:
                    args.batch_size = deepspeed_config["train_micro_batch_size_per_gpu"]
                if "gradient_accumulation_steps" in deepspeed_config:
                    args.gradient_accumulation_steps = deepspeed_config["gradient_accumulation_steps"]
                else:
                    args.gradient_accumulation_steps = None
                if "optimizer" in deepspeed_config:
                    optimizer_params_config = deepspeed_config["optimizer"].get("params", {})
                    args.lr = optimizer_params_config.get("lr", args.lr)
                    args.weight_decay = optimizer_params_config.get("weight_decay", args.weight_decay)
            args.deepspeed_config = deepspeed_config

    # Initialize distributed and random seed because it always seems to be necessary.
    initialize_distributed(args)
    args.seed = args.seed + mpu.get_data_parallel_rank()
    set_random_seed(args.seed)
    return args
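

# Minimal usage sketch (illustrative; the config path and `build_model` are assumptions,
# not names shipped with this module):
#
#   args = get_args(["--base", "configs/inference.yaml", "--mode", "inference"])
#   model = build_model(args.model_config)  # hypothetical consumer of the parsed args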


def initialize_distributed(args):
    """Initialize torch.distributed."""
    if torch.distributed.is_initialized():
        if mpu.model_parallel_is_initialized():
            if args.model_parallel_size != mpu.get_model_parallel_world_size():
                raise ValueError(
                    "model_parallel_size is inconsistent with prior configuration. "
                    "We currently do not support changing model_parallel_size."
                )
            return False
        else:
            if args.model_parallel_size > 1:
                warnings.warn(
                    "model_parallel_size > 1 but torch.distributed is not initialized via SAT. "
                    "Please carefully verify its correctness on your own."
                )
            mpu.initialize_model_parallel(args.model_parallel_size)
        return True

    # The automatic assignment of devices has been moved to arguments.py.
    if args.device == "cpu":
        pass
    else:
        torch.cuda.set_device(args.device)

    # Call the init process
    init_method = "tcp://"
    args.master_ip = os.getenv("MASTER_ADDR", "localhost")

    if args.world_size == 1:
        from sat.helpers import get_free_port

        default_master_port = str(get_free_port())
    else:
        default_master_port = "6000"
    args.master_port = os.getenv("MASTER_PORT", default_master_port)
    init_method += args.master_ip + ":" + args.master_port
    torch.distributed.init_process_group(
        backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
    )

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Set vae context parallel group equal to model parallel group
    from sgm.util import set_context_parallel_group, initialize_context_parallel

    if args.model_parallel_size <= 2:
        set_context_parallel_group(args.model_parallel_size, mpu.get_model_parallel_group())
    else:
        initialize_context_parallel(2)
    # mpu.initialize_model_parallel(1)

    # Optional DeepSpeed Activation Checkpointing Features
    if args.deepspeed:
        import deepspeed

        deepspeed.init_distributed(
            dist_backend=args.distributed_backend, world_size=args.world_size, rank=args.rank, init_method=init_method
        )
        # # It seems that it has no negative influence to configure it even without using checkpointing.
        # deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config, num_checkpoints=args.num_layers)
    else:
        # In model-only mode we do not want to initialize DeepSpeed, but we still need the
        # model-parallel RNG tracker, because dropout saves the RNG state by default.
        try:
            import deepspeed
            from deepspeed.runtime.activation_checkpointing.checkpointing import (
                _CUDA_RNG_STATE_TRACKER,
                _MODEL_PARALLEL_RNG_TRACKER_NAME,
            )

            _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, 1)  # default seed 1
        except Exception as e:
            from sat.helpers import print_rank0

            print_rank0(str(e), level="DEBUG")
    return True


def process_config_to_args(args):
    """Fetch args from only --base."""
    configs = [OmegaConf.load(cfg) for cfg in args.base]
    config = OmegaConf.merge(*configs)

    args_config = config.pop("args", OmegaConf.create())
    for key in args_config:
        if isinstance(args_config[key], omegaconf.DictConfig) or isinstance(args_config[key], omegaconf.ListConfig):
            arg = OmegaConf.to_object(args_config[key])
        else:
            arg = args_config[key]
        if hasattr(args, key):
            setattr(args, key, arg)

    if "model" in config:
        model_config = config.pop("model", OmegaConf.create())
        args.model_config = model_config
    if "deepspeed" in config:
        deepspeed_config = config.pop("deepspeed", OmegaConf.create())
        args.deepspeed_config = OmegaConf.to_object(deepspeed_config)
    if "data" in config:
        data_config = config.pop("data", OmegaConf.create())
        args.data_config = data_config

    return args
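

# Sketch of the YAML layout process_config_to_args expects (illustrative; only the four
# top-level section names come from the code above, the nested keys are assumptions):
#
#   args:
#     mode: inference
#     batch_size: 1
#   model:
#     target: ...        # stored on args.model_config
#   deepspeed:
#     train_micro_batch_size_per_gpu: 1
#   data:
#     target: ...        # stored on args.data_config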