Spaces:
Sleeping
Sleeping
File size: 11,493 Bytes
df60d6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
"""Main training loop."""
import os
import time
import copy
import json
import pickle
import psutil
import numpy as np
import torch
import dnnlib
from torch_utils import distributed as dist
from torch_utils import training_stats
from torch_utils import misc
#----------------------------------------------------------------------------
def training_loop(
    run_dir             = '.',      # Output directory.
    dataset_kwargs      = {},       # Options for training set.
    data_loader_kwargs  = {},       # Options for torch.utils.data.DataLoader.
    network_kwargs      = {},       # Options for model and preconditioning.
    loss_kwargs         = {},       # Options for loss function.
    optimizer_kwargs    = {},       # Options for optimizer.
    augment_kwargs      = None,     # Options for augmentation pipeline, None = disable.
    seed                = 0,        # Global random seed.
    batch_size          = 512,      # Total batch size for one training iteration.
    batch_gpu           = None,     # Limit batch size per GPU, None = no limit.
    total_kimg          = 200000,   # Training duration, measured in thousands of training images.
    ema_halflife_kimg   = 500,      # Half-life of the exponential moving average (EMA) of model weights.
    ema_rampup_ratio    = 0.05,     # EMA ramp-up coefficient, None = no rampup.
    lr_rampup_kimg      = 10000,    # Learning rate ramp-up duration.
    loss_scaling        = 1,        # Loss scaling factor for reducing FP16 under/overflows.
    kimg_per_tick       = 50,       # Interval of progress prints.
    snapshot_ticks      = 50,       # How often to save network snapshots, None = disable.
    state_dump_ticks    = 500,      # How often to dump training state, None = disable.
    resume_pkl          = None,     # Start from the given network snapshot, None = random initialization.
    resume_state_dump   = None,     # Start from the given training state, None = reset training state.
    resume_kimg         = 0,        # Start from the given training progress.
    cudnn_benchmark     = True,     # Enable torch.backends.cudnn.benchmark?
    device              = torch.device('cuda'),  # Device holding the network and the training tensors.
):
    """Train a network, periodically printing progress and saving results.

    Setup: seeds the NumPy/PyTorch RNGs per rank, derives the per-GPU batch
    size and the number of gradient accumulation rounds, constructs the
    dataset, network, loss, optimizer, optional augmentation pipeline, a
    DistributedDataParallel wrapper and an EMA copy of the network, and
    optionally restores weights from `resume_pkl` and/or training state
    from `resume_state_dump`.

    Each iteration accumulates gradients over the accumulation rounds,
    applies the linear learning-rate ramp-up, replaces non-finite gradient
    entries (NaN -> 0, +/-inf -> +/-1e5), steps the optimizer and updates
    the EMA weights. Once per tick (every `kimg_per_tick` thousand images,
    plus the first iteration and the final one) it prints a status line,
    checks `dist.should_stop()`, and on rank 0 writes to `run_dir`:

      * ``network-snapshot-<kimg>.pkl``  every `snapshot_ticks` ticks,
      * ``training-state-<kimg>.pt``     every `state_dump_ticks` ticks,
      * one JSON line appended to ``stats.jsonl`` every tick.

    The loop ends when `total_kimg` thousand images have been processed or
    an abort was requested. The dict-valued arguments are only unpacked or
    copied here, never modified in place. Returns None.
    """
    # Initialize.
    start_time = time.time()
    # Derive a distinct NumPy seed per rank, then seed PyTorch from NumPy.
    np.random.seed((seed * dist.get_world_size() + dist.get_rank()) % (1 << 31))
    torch.manual_seed(np.random.randint(1 << 31))
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

    # Select batch size per GPU.
    batch_gpu_total = batch_size // dist.get_world_size()
    if batch_gpu is None or batch_gpu > batch_gpu_total:
        batch_gpu = batch_gpu_total
    num_accumulation_rounds = batch_gpu_total // batch_gpu
    assert batch_size == batch_gpu * num_accumulation_rounds * dist.get_world_size(), \
        'batch_size must equal batch_gpu * num_accumulation_rounds * world_size'

    # Load dataset.
    dist.print0('Loading dataset...')
    dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs) # subclass of training.dataset.Dataset
    dataset_sampler = misc.InfiniteSampler(dataset=dataset_obj, rank=dist.get_rank(), num_replicas=dist.get_world_size(), seed=seed)
    dataset_iterator = iter(torch.utils.data.DataLoader(dataset=dataset_obj, sampler=dataset_sampler, batch_size=batch_gpu, **data_loader_kwargs))

    # Construct network.
    dist.print0('Constructing network...')
    interface_kwargs = dict(img_resolution=dataset_obj.resolution, img_channels=dataset_obj.num_channels, label_dim=dataset_obj.label_dim)
    net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
    net.train().requires_grad_(True).to(device)
    if dist.get_rank() == 0:
        # Print a module summary on rank 0 using dummy inputs.
        with torch.no_grad():
            images = torch.zeros([batch_gpu, net.img_channels, net.img_resolution, net.img_resolution], device=device)
            sigma = torch.ones([batch_gpu], device=device)
            labels = torch.zeros([batch_gpu, net.label_dim], device=device)
            misc.print_module_summary(net, [images, sigma, labels], max_nesting=2)

    # Setup optimizer.
    dist.print0('Setting up optimizer...')
    loss_fn = dnnlib.util.construct_class_by_name(**loss_kwargs) # training.loss.(VP|VE|EDM)Loss
    optimizer = dnnlib.util.construct_class_by_name(params=net.parameters(), **optimizer_kwargs) # subclass of torch.optim.Optimizer
    augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs) if augment_kwargs is not None else None # training.augment.AugmentPipe
    ddp = torch.nn.parallel.DistributedDataParallel(net, device_ids=[device], broadcast_buffers=False)
    ema = copy.deepcopy(net).eval().requires_grad_(False)

    # Resume training from previous snapshot.
    if resume_pkl is not None:
        dist.print0(f'Loading network weights from "{resume_pkl}"...')
        if dist.get_rank() != 0:
            torch.distributed.barrier() # rank 0 goes first
        with dnnlib.util.open_url(resume_pkl, verbose=(dist.get_rank() == 0)) as f:
            # NOTE(security): unpickling runs arbitrary code embedded in the
            # file; only resume from snapshots obtained from a trusted source.
            data = pickle.load(f)
        if dist.get_rank() == 0:
            torch.distributed.barrier() # other ranks follow
        # Both the trainable network and its EMA copy start from the snapshot's EMA weights.
        misc.copy_params_and_buffers(src_module=data['ema'], dst_module=net, require_all=False)
        misc.copy_params_and_buffers(src_module=data['ema'], dst_module=ema, require_all=False)
        del data # conserve memory
    if resume_state_dump:
        dist.print0(f'Loading training state from "{resume_state_dump}"...')
        # NOTE(review): the dump holds a whole module object (see torch.save
        # below); PyTorch releases that default to weights_only=True may
        # refuse to load it -- confirm against the installed version.
        data = torch.load(resume_state_dump, map_location=torch.device('cpu'))
        misc.copy_params_and_buffers(src_module=data['net'], dst_module=net, require_all=True)
        optimizer.load_state_dict(data['optimizer_state'])
        del data # conserve memory

    # Train.
    dist.print0(f'Training for {total_kimg} kimg...')
    dist.print0()
    cur_nimg = resume_kimg * 1000
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    maintenance_time = tick_start_time - start_time
    dist.update_progress(cur_nimg // 1000, total_kimg)
    stats_path = os.path.join(run_dir, 'stats.jsonl')
    while True:

        # Accumulate gradients.
        optimizer.zero_grad(set_to_none=True)
        for round_idx in range(num_accumulation_rounds):
            # The flag passed to ddp_sync is True only on the last accumulation round.
            with misc.ddp_sync(ddp, (round_idx == num_accumulation_rounds - 1)):
                images, labels = next(dataset_iterator)
                images = images.to(device).to(torch.float32) / 127.5 - 1
                labels = labels.to(device)
                loss = loss_fn(net=ddp, images=images, labels=labels, augment_pipe=augment_pipe)
                training_stats.report('Loss/loss', loss)
                loss.sum().mul(loss_scaling / batch_gpu_total).backward()

        # Update weights.
        for g in optimizer.param_groups:
            # Linear ramp from 0 to the configured learning rate over lr_rampup_kimg.
            g['lr'] = optimizer_kwargs['lr'] * min(cur_nimg / max(lr_rampup_kimg * 1000, 1e-8), 1)
        for param in net.parameters():
            if param.grad is not None:
                # Sanitize gradients in place so non-finite values do not reach the optimizer.
                torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
        optimizer.step()

        # Update EMA.
        ema_halflife_nimg = ema_halflife_kimg * 1000
        if ema_rampup_ratio is not None:
            # Early in training the half-life is capped proportionally to the progress so far.
            ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
        ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
        for p_ema, p_net in zip(ema.parameters(), net.parameters()):
            p_ema.copy_(p_net.detach().lerp(p_ema, ema_beta))

        # Perform maintenance tasks once per tick.
        cur_nimg += batch_size
        done = (cur_nimg >= total_kimg * 1000)
        if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000):
            continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
        fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
        fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
        fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
        fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
        fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"]
        fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"]
        fields += [f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2**30):<6.2f}"]
        # Reset the same device whose peak statistics were just read.
        torch.cuda.reset_peak_memory_stats(device)
        dist.print0(' '.join(fields))

        # Check for abort.
        if (not done) and dist.should_stop():
            done = True
            dist.print0()
            dist.print0('Aborting...')

        # Save network snapshot.
        if (snapshot_ticks is not None) and (done or cur_tick % snapshot_ticks == 0):
            data = dict(ema=ema, loss_fn=loss_fn, augment_pipe=augment_pipe, dataset_kwargs=dict(dataset_kwargs))
            for key, value in data.items():
                if isinstance(value, torch.nn.Module):
                    # Snapshot a frozen CPU copy; the consistency check runs on every rank.
                    value = copy.deepcopy(value).eval().requires_grad_(False)
                    misc.check_ddp_consistency(value)
                    data[key] = value.cpu()
                del value # conserve memory
            if dist.get_rank() == 0:
                with open(os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl'), 'wb') as f:
                    pickle.dump(data, f)
            del data # conserve memory

        # Save full dump of the training state.
        if (state_dump_ticks is not None) and (done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
            torch.save(dict(net=net, optimizer_state=optimizer.state_dict()), os.path.join(run_dir, f'training-state-{cur_nimg//1000:06d}.pt'))

        # Update logs.
        training_stats.default_collector.update()
        if dist.get_rank() == 0:
            # Append one record per tick; the handle is closed (and flushed)
            # right away so it cannot leak on abort or on an exception.
            with open(stats_path, 'at') as stats_jsonl:
                stats_jsonl.write(json.dumps(dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n')
        dist.update_progress(cur_nimg // 1000, total_kimg)

        # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        # Time spent on this tick's status/snapshot/logging work, reported with the next tick.
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

    # Done.
    dist.print0()
    dist.print0('Exiting...')
#----------------------------------------------------------------------------
|