import logging
import pathlib
from typing import List

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import SE3TransformerPooled
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
    PerformanceCallback
from se3_transformer.runtime.inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
    using_tensor_cores, increase_l2_fetch_granularity
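# Note: FusedAdam and FusedLAMB are NVIDIA Apex optimizers; the import above fails at startup
# if Apex is not installed with its CUDA extensions.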
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only once per node) """
    if get_local_rank() == 0:
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        checkpoint = {
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        for callback in callbacks:
            callback.on_checkpoint_save(checkpoint)

        torch.save(checkpoint, str(path))
        logging.info(f'Saved checkpoint to {str(path)}')


def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Loads model, optimizer and epoch states from path """
    checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
    if isinstance(model, DistributedDataParallel):
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for callback in callbacks:
        callback.on_checkpoint_load(checkpoint)

    logging.info(f'Loaded checkpoint from {str(path)}')
    return checkpoint['epoch']
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
    losses = []
    for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                         desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)
            loss = loss_fn(pred, target) / args.accumulate_grad_batches

        grad_scaler.scale(loss).backward()

        # Step the optimizer only every accumulate_grad_batches iterations (or on the last batch)
        if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
            if args.gradient_clip:
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)

            grad_scaler.step(optimizer)
            grad_scaler.update()
            optimizer.zero_grad()

        losses.append(loss.item())

    return np.mean(losses)
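# With gradient accumulation, each micro-batch loss is divided by args.accumulate_grad_batches
# before backward(), so the accumulated gradient matches that of a single batch of size
# args.batch_size * args.accumulate_grad_batches per GPU.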
def train(model: nn.Module,
          loss_fn: _Loss,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          callbacks: List[BaseCallback],
          logger: Logger,
          args):
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,
                           args)
        if dist.is_initialized():
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)

        if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
            evaluate(model, val_dataloader, callbacks, args)
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)

    for callback in callbacks:
        callback.on_fit_end()


def print_parameters_count(model):
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f'Number of trainable parameters: {num_params_trainable}')
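# A minimal sketch of a custom callback, assuming BaseCallback provides no-op defaults for the
# hooks invoked above (on_fit_start, on_batch_start, on_epoch_end, on_validation_end,
# on_checkpoint_save, on_checkpoint_load, on_fit_end). The class below is illustrative only:
#
#   class EpochCounterCallback(BaseCallback):
#       def __init__(self):
#           self.epochs_seen = 0
#
#       def on_epoch_end(self):
#           # Called once per epoch by the training loop above
#           self.epochs_seen += 1
#
# It would be appended to the callbacks list before calling train(...).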
if __name__ == '__main__':
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|      Training procedure     |')
    logging.info('===============================')

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    logger = LoggerCollection([
        DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
        WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
    ])

    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(args.amp),
        **vars(args)
    )
    loss_fn = nn.L1Loss()

    if args.benchmark:
        logging.info('Running benchmark mode')
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
    else:
        callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
                     QM9LRSchedulerCallback(logger, epochs=args.epochs)]

    if is_distributed:
        gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())

    print_parameters_count(model)
    logger.log_hyperparams(vars(args))
    increase_l2_fetch_granularity()
    train(model,
          loss_fn,
          datamodule.train_dataloader(),
          datamodule.val_dataloader(),
          callbacks,
          logger,
          args)

    logging.info('Training finished successfully')
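# Example invocation (the module path and exact flag names are assumptions based on the
# args.* attributes used above):
#   python -m se3_transformer.runtime.training --amp --epochs 100 --batch_size 240 --task homo
# Multi-GPU runs go through init_distributed(), so the script would typically be started with
# a distributed launcher such as torchrun or mpirun.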