| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| """Train a GAN using the techniques described in the paper |
| "Training Generative Adversarial Networks with Limited Data".""" |
| import sys |
| import os |
|
|
| sys.path.insert(1, os.path.join(sys.path[0], "..")) |
| import click |
| import re |
| import json |
| import tempfile |
| import torch |
| import dnnlib |
|
|
| import numpy as np |
|
|
| import parser |
|
|
| from training import training_loop |
| from metrics import metric_main |
| from torch_utils import training_stats |
| from torch_utils import custom_ops |
|
|
|
|
| |
|
|
|
|
class UserError(Exception):
    """Raised for invalid or inconsistent command-line options supplied by the user."""
|
|
|
|
| |
|
|
|
|
def setup_training_loop_kwargs(
    # General options (not included in desc).
    exp_name=None,
    slurm=None,
    gpus=None,
    nodes=None,
    snap=None,
    metrics=None,
    seed=None,
    # Dataset.
    data=None,
    class_cond=None,
    subset=None,
    mirror=None,
    # Instance / feature conditioning.
    instance_cond=None,
    feature_augmentation=None,
    root_feats=None,
    root_nns=None,
    label_dim=None,
    # Base config.
    cfg=None,
    lrate=None,
    gamma=None,
    kimg=None,
    batch=None,
    num_channel_g=None,
    num_channel_d=None,
    channel_max_g=None,
    channel_max_d=None,
    hidden_dim_c=None,
    hidden_dim_h=None,
    es_patience=None,
    # Discriminator augmentation.
    aug=None,
    p=None,
    target=None,
    augpipe=None,
    # Transfer learning.
    resume=None,
    freezed=None,
    # Performance options (not included in desc).
    fp32=None,
    nhwc=None,
    allow_tf32=None,
    nobench=None,
    workers=None,
    **kwargs,
):
    """Validate command-line options and build the keyword arguments for training.

    Every option defaults to ``None``, meaning "use the built-in default".
    Unknown keyword arguments are accepted and ignored (``**kwargs``), so a
    full argparse namespace can be splatted in directly.

    Returns:
        ``(desc, args)``. ``desc`` is a short run-description string assembled
        from the dataset name and the non-default options. ``args`` is a
        ``dnnlib.EasyDict`` holding the kwargs for the training loop.

    Raises:
        UserError: if an option value is invalid, or if the dataset cannot be
            opened (an ``IOError`` from the dataset constructor).
    """
    args = dnnlib.EasyDict()

    # ------------------------------------------
    # General options: gpus, nodes, snap, metrics, seed
    # ------------------------------------------

    if gpus is None:
        gpus = 1
    assert isinstance(gpus, int)
    if not (gpus >= 1 and gpus & (gpus - 1) == 0):
        raise UserError("--gpus must be a power of two")
    # Previously `gpus * nodes` raised TypeError when `nodes` was omitted;
    # default to a single node like every other option defaults itself.
    if nodes is None:
        nodes = 1
    assert isinstance(nodes, int)
    if nodes < 1:
        raise UserError("--nodes must be at least 1")
    args.num_gpus = gpus * nodes

    if snap is None:
        snap = 50
    assert isinstance(snap, int)
    if snap < 1:
        raise UserError("--snap must be at least 1")
    args.image_snapshot_ticks = snap
    args.network_snapshot_ticks = snap
    args.es_patience = es_patience

    if metrics is None:
        metrics = ["fid50k_full"]
    assert isinstance(metrics, list)
    if not all(metric_main.is_valid_metric(metric) for metric in metrics):
        raise UserError(
            "\n".join(
                ["--metrics can only contain the following values:"]
                + metric_main.list_valid_metrics()
            )
        )
    args.metrics = metrics

    if seed is None:
        seed = 0
    assert isinstance(seed, int)
    args.random_seed = seed

    # ------------------------------------------
    # Dataset: data, class_cond, subset, mirror
    # ------------------------------------------

    # An explicit error (not an assert) so the check survives `python -O`.
    if data is None:
        raise UserError("--data must be specified")
    assert isinstance(data, str)

    class_name = "data_utils.datasets_common.ILSVRC_HDF5_feats"
    args.class_cond = class_cond
    args.instance_cond = instance_cond

    if mirror is None:
        mirror = False
    assert isinstance(mirror, bool)

    args.training_set_kwargs = dnnlib.EasyDict(
        class_name=class_name,
        root=data,
        max_size=None,
        xflip=False,
        load_labels=class_cond,
        load_features=instance_cond,
        root_feats=root_feats,
        root_nns=root_nns,
        transform=None,
        label_dim=label_dim,
        feature_dim=2048,
        apply_norm=False,
        label_onehot=True,
        feature_augmentation=feature_augmentation,
    )
    args.data_loader_kwargs = dnnlib.EasyDict(
        pin_memory=True, num_workers=3, prefetch_factor=2
    )
    try:
        # Instantiate the dataset once, only to read its resolution and size.
        training_set = dnnlib.util.construct_class_by_name(
            **args.training_set_kwargs
        )
        args.training_set_kwargs.resolution = training_set.resolution
        args.training_set_kwargs.max_size = len(training_set)
        desc = os.path.splitext(os.path.basename(data))[0]
        del training_set
    except IOError as err:
        raise UserError(f"--data: {err}") from err

    if mirror:
        desc += "-mirror"
        args.training_set_kwargs.xflip = True

    if subset is not None:
        assert isinstance(subset, int)
        if not 1 <= subset <= args.training_set_kwargs.max_size:
            raise UserError(
                f"--subset must be between 1 and {args.training_set_kwargs.max_size}"
            )
        desc += f"-subset{subset}"
        if subset < args.training_set_kwargs.max_size:
            args.training_set_kwargs.max_size = subset
            args.training_set_kwargs.random_seed = args.random_seed

    # ------------------------------------------
    # Base config: cfg, gamma, kimg, batch
    # ------------------------------------------

    if cfg is None:
        cfg = "auto"
    assert isinstance(cfg, str)
    desc += f"-{cfg}"

    # -1 entries of "auto" are placeholders filled in from resolution / GPU count.
    cfg_specs = {
        "auto": dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1,
                     lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2),
        "stylegan2": dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1,
                          lrate=0.002, gamma=10, ema=10, ramp=None, map=8),
        "paper256": dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5,
                         lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
        "paper512": dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1,
                         lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
        "paper1024": dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1,
                          lrate=0.002, gamma=2, ema=10, ramp=None, map=8),
        "cifar": dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1,
                      lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2),
    }

    # Same reporting style as the --aug check below, instead of a bare assert.
    if cfg not in cfg_specs:
        raise UserError(f"--cfg={cfg} not supported")
    spec = dnnlib.EasyDict(cfg_specs[cfg])
    if cfg == "auto":
        desc += f"{gpus:d}"
        spec.ref_gpus = args.num_gpus
        res = args.training_set_kwargs.resolution
        spec.mb = max(min(args.num_gpus * min(4096 // res, 32), 64), args.num_gpus)
        spec.mbstd = min(spec.mb // args.num_gpus, 4)
        spec.fmaps = 1 if res >= 512 else 0.5
        spec.lrate = 0.002 if res >= 1024 else 0.0025
        spec.gamma = 0.0002 * (res ** 2) / spec.mb
        spec.ema = spec.mb * 10 / 32

    args.G_kwargs = dnnlib.EasyDict(
        class_name="training.networks.Generator",
        z_dim=512,
        w_dim=512,
        mapping_kwargs=dnnlib.EasyDict(),
        synthesis_kwargs=dnnlib.EasyDict(),
    )
    args.D_kwargs = dnnlib.EasyDict(
        class_name="training.networks.Discriminator",
        block_kwargs=dnnlib.EasyDict(),
        mapping_kwargs=dnnlib.EasyDict(),
        epilogue_kwargs=dnnlib.EasyDict(),
    )
    args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int(
        spec.fmaps * 32768
    )
    args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512
    args.G_kwargs.mapping_kwargs.num_layers = spec.map
    if hidden_dim_c is not None:
        args.G_kwargs.mapping_kwargs.embed_features = hidden_dim_c
        args.D_kwargs.mapping_kwargs.embed_features = hidden_dim_c
    if hidden_dim_h is not None:
        args.G_kwargs.mapping_kwargs.embed_features_feat = hidden_dim_h
        args.D_kwargs.mapping_kwargs.embed_features_feat = hidden_dim_h
    args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4
    args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256
    args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd

    # Explicit channel overrides take precedence over the cfg-derived values.
    args.exp_name = exp_name
    if num_channel_d is not None:
        args.D_kwargs.channel_base = num_channel_d
    if channel_max_d is not None:
        args.D_kwargs.channel_max = channel_max_d
    if num_channel_g is not None:
        args.G_kwargs.synthesis_kwargs.channel_base = num_channel_g
    if channel_max_g is not None:
        args.G_kwargs.synthesis_kwargs.channel_max = channel_max_g

    if lrate is not None:
        spec.lrate = lrate

    args.G_opt_kwargs = dnnlib.EasyDict(
        class_name="torch.optim.Adam", lr=spec.lrate, betas=[0, 0.99], eps=1e-8
    )
    args.D_opt_kwargs = dnnlib.EasyDict(
        class_name="torch.optim.Adam", lr=spec.lrate, betas=[0, 0.99], eps=1e-8
    )
    args.loss_kwargs = dnnlib.EasyDict(
        class_name="training.loss.StyleGAN2Loss", r1_gamma=spec.gamma
    )

    args.total_kimg = spec.kimg
    args.batch_size = spec.mb
    args.batch_gpu = spec.mb // spec.ref_gpus
    args.ema_kimg = spec.ema
    args.ema_rampup = spec.ramp

    if cfg == "cifar":
        args.loss_kwargs.pl_weight = 0
        args.loss_kwargs.style_mixing_prob = 0
        args.D_kwargs.architecture = "orig"

    if gamma is not None:
        assert isinstance(gamma, float)
        if not gamma >= 0:
            raise UserError("--gamma must be non-negative")
        desc += f"-gamma{gamma:g}"
        args.loss_kwargs.r1_gamma = gamma

    if kimg is not None:
        assert isinstance(kimg, int)
        if not kimg >= 1:
            raise UserError("--kimg must be at least 1")
        desc += f"-kimg{kimg:d}"
        args.total_kimg = kimg

    if batch is not None:
        assert isinstance(batch, int)
        if not (batch >= 1 and batch % args.num_gpus == 0):
            raise UserError(
                "--batch must be at least 1 and divisible by --gpus and --nodes"
            )
        desc += f"-batch{batch}"
        args.batch_size = batch
        args.batch_gpu = batch // args.num_gpus
    # Always recorded: main() reads args.slurm regardless of --batch.
    args.slurm = slurm

    # ---------------------------------------------------
    # Discriminator augmentation: aug, p, target, augpipe
    # ---------------------------------------------------

    if aug is None:
        aug = "ada"
    else:
        assert isinstance(aug, str)
        desc += f"-{aug}"

    if aug == "ada":
        args.ada_target = 0.6
    elif aug == "noaug":
        pass
    elif aug == "fixed":
        if p is None:
            raise UserError(f"--aug={aug} requires specifying --p")
    else:
        raise UserError(f"--aug={aug} not supported")

    if p is not None:
        assert isinstance(p, float)
        if aug != "fixed":
            raise UserError("--p can only be specified with --aug=fixed")
        if not 0 <= p <= 1:
            raise UserError("--p must be between 0 and 1")
        desc += f"-p{p:g}"
        args.augment_p = p

    if target is not None:
        assert isinstance(target, float)
        if aug != "ada":
            raise UserError("--target can only be specified with --aug=ada")
        if not 0 <= target <= 1:
            raise UserError("--target must be between 0 and 1")
        desc += f"-target{target:g}"
        args.ada_target = target

    assert augpipe is None or isinstance(augpipe, str)
    if augpipe is None:
        augpipe = "bgc"
    else:
        if aug == "noaug":
            raise UserError("--augpipe cannot be specified with --aug=noaug")
        desc += f"-{augpipe}"

    # Composite pipelines are unions of the primitive groups below.
    blit = dict(xflip=1, rotate90=1, xint=1)
    geom = dict(scale=1, rotate=1, aniso=1, xfrac=1)
    color = dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1)
    augpipe_specs = {
        "blit": blit,
        "geom": geom,
        "color": color,
        "filter": dict(imgfilter=1),
        "noise": dict(noise=1),
        "cutout": dict(cutout=1),
        "bg": {**blit, **geom},
        "bgc": {**blit, **geom, **color},
        "bgcf": {**blit, **geom, **color, "imgfilter": 1},
        "bgcfn": {**blit, **geom, **color, "imgfilter": 1, "noise": 1},
        "bgcfnc": {**blit, **geom, **color, "imgfilter": 1, "noise": 1, "cutout": 1},
    }

    if augpipe not in augpipe_specs:
        raise UserError(f"--augpipe={augpipe} not supported")
    if aug != "noaug":
        args.augment_kwargs = dnnlib.EasyDict(
            class_name="training.augment.AugmentPipe", **augpipe_specs[augpipe]
        )

    # ----------------------------------
    # Transfer learning: resume, freezed
    # ----------------------------------

    resume_specs = {
        "ffhq256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl",
        "ffhq512": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl",
        "ffhq1024": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl",
        "celebahq256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl",
        "lsundog256": "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl",
    }

    assert resume is None or isinstance(resume, str)
    if resume is None:
        resume = "noresume"
    elif resume == "noresume":
        desc += "-noresume"
    elif resume in resume_specs:
        desc += f"-resume{resume}"
        args.resume_pkl = resume_specs[resume]
    else:
        # Anything else is treated as a custom path/URL to a network pickle.
        desc += "-resumecustom"
        args.resume_pkl = resume

    if resume != "noresume":
        args.ada_kimg = 100
        args.ema_rampup = None

    if freezed is not None:
        assert isinstance(freezed, int)
        if not freezed >= 0:
            raise UserError("--freezed must be non-negative")
        desc += f"-freezed{freezed:d}"
        args.D_kwargs.block_kwargs.freeze_layers = freezed

    # -------------------------------------------------
    # Performance options: fp32, nhwc, nobench, workers
    # -------------------------------------------------

    if fp32 is None:
        fp32 = False
    assert isinstance(fp32, bool)
    if fp32:
        args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0
        args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None

    if nhwc is None:
        nhwc = False
    assert isinstance(nhwc, bool)
    if nhwc:
        args.G_kwargs.synthesis_kwargs.fp16_channels_last = (
            args.D_kwargs.block_kwargs.fp16_channels_last
        ) = True

    if nobench is None:
        nobench = False
    assert isinstance(nobench, bool)
    if nobench:
        args.cudnn_benchmark = False

    if allow_tf32 is None:
        allow_tf32 = False
    assert isinstance(allow_tf32, bool)
    if allow_tf32:
        args.allow_tf32 = True

    if workers is not None:
        assert isinstance(workers, int)
        if not workers >= 1:
            raise UserError("--workers must be at least 1")
        args.data_loader_kwargs.num_workers = workers

    return desc, args
|
|
|
|
| |
|
|
|
|
def subprocess_fn(rank, args, world_size=1, dist_url="", temp_dir="", slurm=False):
    """Per-process entry point: set up logging and process groups, then train.

    Three launch modes are supported:

    * ``slurm`` truthy: ``rank``/``local_rank`` are read from the
      ``SLURM_PROCID``/``SLURM_LOCALID`` environment variables (the ``rank``
      argument is ignored). NCCL is initialised via ``dist_url``/``world_size``.
    * ``args.num_gpus > 1`` without SLURM: a file-based rendezvous inside
      ``temp_dir`` is used (gloo on Windows, NCCL elsewhere).
    * otherwise: single process, rank 0.

    Args:
        rank: Process rank supplied by the launcher (ignored under SLURM).
        args: EasyDict of training options; splatted into ``training_loop``.
        world_size: Total process count for the SLURM rendezvous.
        dist_url: ``tcp://host:port`` rendezvous URL used under SLURM.
        temp_dir: Scratch directory, also forwarded to ``training_loop``.
        slurm: Whether this process was launched by SLURM.
    """
    dnnlib.util.Logger(
        file_name=os.path.join(args.run_dir, "log.txt"),
        file_mode="a",
        should_flush=True,
    )

    # Init torch.distributed.
    if slurm:
        rank = int(os.environ.get("SLURM_PROCID"))
        local_rank = int(os.environ.get("SLURM_LOCALID"))
        torch.distributed.init_process_group(
            backend="nccl", init_method=dist_url, rank=rank, world_size=world_size
        )
    elif args.num_gpus > 1:
        init_file = os.path.abspath(os.path.join(temp_dir, ".torch_distributed_init"))
        if os.name == "nt":
            init_method = "file:///" + init_file.replace("\\", "/")
            backend = "gloo"
        else:
            init_method = f"file://{init_file}"
            backend = "nccl"
        torch.distributed.init_process_group(
            backend=backend,
            init_method=init_method,
            rank=rank,
            world_size=args.num_gpus,
        )
        local_rank = rank
    else:
        rank = local_rank = 0

    # Init torch_utils stats in every mode. This was previously only done for
    # the file-based multi-GPU path, so under SLURM the stats (including the
    # ones ADA uses to adapt p) were never all-reduced and every rank
    # considered itself rank 0. The sync device is the process's local GPU.
    distributed = bool(slurm) or args.num_gpus > 1
    sync_device = torch.device("cuda", local_rank) if distributed else None
    training_stats.init_multiprocessing(rank=rank, sync_device=sync_device)
    if rank != 0:
        custom_ops.verbosity = "none"

    # Execute training loop.
    training_loop.training_loop(
        rank=rank, local_rank=local_rank, temp_dir=temp_dir, **args
    )
|
|
|
|
| |
|
|
|
|
class CommaSeparatedList(click.ParamType):
    """Click parameter type that splits ``"a,b,c"`` into ``["a", "b", "c"]``.

    ``None``, the empty string and ``"none"`` (case-insensitive) all map to an
    empty list.
    """

    name = "list"

    def convert(self, value, param, ctx):
        """Convert *value* to a list of strings; *param* and *ctx* are unused."""
        _ = param, ctx
        # Click may invoke convert() on values that are already converted
        # (e.g. list-typed defaults); pass them through instead of crashing
        # on `.lower()`.
        if isinstance(value, (list, tuple)):
            return list(value)
        if value is None or value.lower() == "none" or value == "":
            return []
        return value.split(",")
|
|
|
|
| |
|
|
|
|
def _namespace_option(namespace, name, explicit, fallback):
    """Return *explicit* if given, else ``namespace.<name>`` if set, else *fallback*."""
    if explicit is not None:
        return explicit
    value = getattr(namespace, name, None)
    return fallback if value is None else value


def main(args, outdir=None, master_node=None, port=None, dry_run=None, **config_kwargs):
    """Train a GAN using the techniques described in the paper
    "Training Generative Adversarial Networks with Limited Data".

    Examples:

    \b
    # Train with custom dataset using 1 GPU.
    python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1

    \b
    # Train class-conditional CIFAR-10 using 2 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\
        --gpus=2 --cfg=cifar --cond=1

    \b
    # Transfer learn MetFaces from FFHQ using 4 GPUs.
    python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\
        --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10

    \b
    # Reproduce original StyleGAN2 config F.
    python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\
        --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug

    \b
    Base configs (--cfg):
      auto       Automatically select reasonable defaults based on resolution
                 and GPU count. Good starting point for new datasets.
      stylegan2  Reproduce results for StyleGAN2 config F at 1024x1024.
      paper256   Reproduce results for FFHQ and LSUN Cat at 256x256.
      paper512   Reproduce results for BreCaHAD and AFHQ at 512x512.
      paper1024  Reproduce results for MetFaces at 1024x1024.
      cifar      Reproduce results for CIFAR-10 at 32x32.

    \b
    Transfer learning source networks (--resume):
      ffhq256        FFHQ trained at 256x256 resolution.
      ffhq512        FFHQ trained at 512x512 resolution.
      ffhq1024       FFHQ trained at 1024x1024 resolution.
      celebahq256    CelebA-HQ trained at 256x256 resolution.
      lsundog256     LSUN Dog trained at 256x256 resolution.
      <PATH or URL>  Custom network pickle.

    Python-level arguments:
      args: Parsed option namespace; its attributes feed
          ``setup_training_loop_kwargs``.
      outdir, master_node, port, dry_run: Taken from these arguments when
          given. Otherwise they come from the same-named attributes of
          ``args``, falling back to ``None`` (an error for ``outdir``), "",
          40000 and False respectively.
      **config_kwargs: Extra training options; these override ``args``.
    """
    dnnlib.util.Logger(should_flush=True)

    # The __main__ block calls main(args) alone, so every launch option must
    # be resolvable from the namespace (previously a TypeError for outdir).
    outdir = _namespace_option(args, "outdir", outdir, None)
    if outdir is None:
        raise UserError("--outdir must be specified")
    master_node = _namespace_option(args, "master_node", master_node, "")
    port = _namespace_option(args, "port", port, 40000)
    dry_run = _namespace_option(args, "dry_run", dry_run, False)

    # Setup training options. Explicit keyword overrides used to be silently
    # discarded; they now take precedence over the namespace values.
    config_kwargs = {**vars(args), **config_kwargs}
    run_desc, args = setup_training_loop_kwargs(**config_kwargs)
    # NOTE(review): this overrides the metric list validated from --metrics;
    # kept as-is, confirm it is intended.
    args.metrics = ["fid50k_full"]

    # Pick output directory: next free numeric id, unless an exp_name is given.
    if args.exp_name is None:
        prev_run_dirs = []
        if os.path.isdir(outdir):
            prev_run_dirs = [
                x for x in os.listdir(outdir) if os.path.isdir(os.path.join(outdir, x))
            ]
        prev_run_ids = [re.match(r"^\d+", x) for x in prev_run_dirs]
        prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None]
        cur_run_id = max(prev_run_ids, default=-1) + 1
        args.run_dir = os.path.join(outdir, f"{cur_run_id:05d}-{run_desc}")
        assert not os.path.exists(args.run_dir)
    else:
        args.run_dir = os.path.join(outdir, args.exp_name)

    # Print options.
    print()
    print("Training options:")
    print()
    print(f"Output directory:   {args.run_dir}")
    print(f"Training data:      {args.training_set_kwargs.root}")
    print(f"Training duration:  {args.total_kimg} kimg")
    print(f"Number of GPUs:     {args.num_gpus}")
    print(f"Number of images:   {args.training_set_kwargs.max_size}")
    print(f"Image resolution:   {args.training_set_kwargs.resolution}")
    print(f"Conditional model:  {args.training_set_kwargs.load_labels}")
    print(f"Dataset x-flips:    {args.training_set_kwargs.xflip}")
    print()

    # Dry run?
    if dry_run:
        print("Dry run; exiting.")
        return

    # Create output directory (it may already exist when re-using exp_name).
    print("Creating output directory...")
    os.makedirs(args.run_dir, exist_ok=True)
    with open(os.path.join(args.run_dir, "training_options.json"), "wt") as f:
        json.dump(args, f, indent=2)

    # Launch processes. args.slurm is read defensively in case it was not set.
    slurm = args.get("slurm")
    if slurm:
        n_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES"))
        n_gpus_per_node = int(os.environ.get("SLURM_TASKS_PER_NODE").split("(")[0])
        world_size = n_gpus_per_node * n_nodes
        dist_url = "tcp://" + str(master_node) + ":" + str(port)
        print("Dist url ", dist_url)
        temp_dir = "/scratch/slurm_tmpdir/" + str(os.environ.get("SLURM_JOB_ID"))
        subprocess_fn(
            rank=-1,
            args=args,
            world_size=world_size,
            dist_url=dist_url,
            temp_dir=temp_dir,
            slurm=slurm,
        )
    else:
        print("Launching processes...")
        # force=True keeps main() callable more than once in the same process.
        torch.multiprocessing.set_start_method("spawn", force=True)
        with tempfile.TemporaryDirectory() as temp_dir:
            if args.num_gpus == 1:
                subprocess_fn(rank=0, args=args, temp_dir=temp_dir)
            else:
                torch.multiprocessing.spawn(
                    fn=subprocess_fn,
                    args=(args, args.num_gpus, "", temp_dir),
                    nprocs=args.num_gpus,
                )
|
|
|
|
| |
|
|
if __name__ == "__main__":
    # Build the project's argument parser, parse sys.argv and hand off to main().
    cli_parser = parser.get_parser()
    main(cli_parser.parse_args())
|
|
| |
|
|