| import time, os |
|
|
| start = time.time() |
| import torch |
| from torch.nn.parallel import DistributedDataParallel |
| from dgl.dataloading import GraphDataLoader |
| from torch.amp import GradScaler |
| import numpy as np |
| import hydra |
| from omegaconf import DictConfig |
| from physicsnemo.launch.logging import ( |
| PythonLogger, |
| RankZeroLoggingWrapper, |
| ) |
| from physicsnemo.launch.utils import load_checkpoint, save_checkpoint |
| from physicsnemo.distributed.manager import DistributedManager |
|
|
| import json |
| from tqdm import tqdm |
| import random |
|
|
| import models.MeshGraphNet as MeshGraphNet |
| from dataset.Dataset import get_dataset |
| import metrics |
|
|
| import utils |
|
|
| class MGNTrainer: |
| def __init__(self, logger, cfg, dist): |
| |
| self.device = dist.device |
| logger.info(f"Using {self.device} device") |
|
|
| start = time.time() |
| self.trainloader, self.valloader, self.testloader = get_dataset(cfg, self.device) |
| print(f"total time loading dataset: {time.time() - start:.2f} seconds") |
|
|
| dtype_str = getattr(cfg.root_dataset, "dtype", "torch.float32") |
| if isinstance(dtype_str, str) and dtype_str.startswith("torch."): |
| self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32) |
| else: |
| self.dtype = torch.float32 |
|
|
| self.model = utils.build_from_module(cfg.architecture) |
| self.model = self.model.to(dtype=self.dtype, device=self.device) |
| |
| |
|
|
| if cfg.performance.jit: |
| self.model = torch.jit.script(self.model).to(self.device) |
| else: |
| self.model = self.model.to(self.device) |
|
|
| |
| self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.scheduler.lr) |
| self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( |
| self.optimizer, |
| T_max=cfg.training.epochs, |
| eta_min=cfg.scheduler.lr * cfg.scheduler.lr_decay, |
| ) |
| self.scaler = GradScaler('cuda') |
|
|
| |
| self.epoch_init = load_checkpoint( |
| os.path.join(cfg.checkpoints.ckpt_path, cfg.checkpoints.ckpt_name), |
| models=self.model, |
| optimizer=self.optimizer, |
| scheduler=self.scheduler, |
| scaler=self.scaler, |
| device=self.device, |
| ) |
|
|
| self.cfg = cfg |
|
|
| def backward(self, loss): |
| """ |
| Perform backward pass. |
| |
| Arguments: |
| loss: loss value. |
| |
| """ |
| |
| if self.cfg.performance.amp: |
| self.scaler.scale(loss).backward() |
| self.scaler.step(self.optimizer) |
| self.scaler.update() |
| else: |
| loss.backward() |
| self.optimizer.step() |
|
|
| def train(self, graph, metadata): |
| """ |
| Perform one training iteration over one graph. The training is performed |
| over multiple timesteps, where the number of timesteps is specified in |
| the 'stride' parameter. |
| |
| Arguments: |
| graph: the desired graph. |
| |
| Returns: |
| loss: loss value. |
| |
| """ |
| graph = graph.to(self.device, non_blocking=True) |
| globals = metadata['globals'].to(self.device, non_blocking=True) |
| label = metadata['label'].to(self.device, non_blocking=True) |
| weight = metadata['weight'].to(self.device, non_blocking=True) |
|
|
| self.optimizer.zero_grad() |
| pred = self.model(graph.ndata["features"], graph.edata["features"], globals, graph, metadata) |
| loss = metrics.weighted_bce(pred, label, weights=weight) |
| self.backward(loss) |
| return loss.detach() |
|
|
| @torch.no_grad() |
| def eval(self): |
| """ |
| Evaluate the model on one batch. |
| |
| Args: |
| graph (DGLGraph): The input graph. |
| label (Tensor): The target labels. |
| |
| Returns: |
| loss (Tensor): The computed loss value (scalar). |
| """ |
| predictions = [] |
| labels = [] |
| weights = [] |
|
|
| for graph, metadata in self.valloader: |
| |
| graph = graph.to(self.device, non_blocking=True) |
| globals = metadata['globals'].to(self.device, non_blocking=True) |
| label = metadata['label'].to(self.device, non_blocking=True) |
| weight = metadata['weight'].to(self.device, non_blocking=True) |
| |
| pred = self.model(graph.ndata["features"], graph.edata["features"], globals, graph, metadata) |
| predictions.append(pred) |
| labels.append(label) |
| weights.append(weight) |
|
|
| predictions = torch.cat(predictions, dim=0) |
| labels = torch.cat(labels, dim=0) |
| weights = torch.cat(weights, dim=0) |
|
|
| loss = metrics.weighted_bce(predictions, labels, weights=weights) |
| |
| |
| prob = torch.sigmoid(predictions) |
|
|
| |
| prob_flat = prob.detach().to(torch.float32).cpu().numpy().flatten() |
| labels_flat = labels.detach().to(torch.float32).cpu().numpy().flatten() |
|
|
| |
| try: |
| auc = metrics.roc_auc_score(labels_flat, prob_flat) |
| except ValueError: |
| auc = float('nan') |
|
|
| return loss, auc |
|
|
@hydra.main(version_base=None, config_path="./configs/", config_name="tHjb_CP_0_vs_45")
def do_training(cfg: DictConfig):
    """
    Train the model over all graphs in the dataset, checkpointing after
    every epoch.

    Arguments:
        cfg: Hydra configuration tree.
    """
    # Seed all RNG sources for reproducibility.
    random.seed(cfg.random_seed)
    np.random.seed(cfg.random_seed)
    torch.manual_seed(cfg.random_seed)

    # Initialize distributed state (also valid for single-process runs).
    DistributedManager.initialize()
    dist = DistributedManager()

    os.makedirs(cfg.checkpoints.ckpt_path, exist_ok=True)
    logger = PythonLogger("main")
    logger.file_logging(os.path.join(cfg.checkpoints.ckpt_path, "train.log"))

    trainer = MGNTrainer(logger, cfg, dist)

    if dist.distributed:
        # Wrap the model in DDP on a side stream, then make the default
        # stream wait for the wrapping to complete.
        ddps = torch.cuda.Stream()
        with torch.cuda.stream(ddps):
            trainer.model = DistributedDataParallel(
                trainer.model,
                device_ids=[dist.local_rank],
                output_device=dist.device,
                broadcast_buffers=dist.broadcast_buffers,
                find_unused_parameters=dist.find_unused_parameters,
            )
        torch.cuda.current_stream().wait_stream(ddps)

    start = time.time()
    logger.info("Training started...")
    for epoch in range(trainer.epoch_init, cfg.training.epochs):
        # Fix: set train mode once per epoch instead of once per batch, and
        # correct the "trianing" typo in the progress-bar label.
        trainer.model.train()
        train_loss = []
        for graph, metadata in tqdm(trainer.trainloader, desc=f"epoch {epoch} training"):
            loss = trainer.train(graph, metadata)
            train_loss.append(loss.item())

        val_loss, val_auc = trainer.eval()
        train_loss = torch.tensor(train_loss).mean()

        logger.info(
            f"epoch: {epoch}, loss: {train_loss:10.3e}, val_loss: {val_loss:10.3e}, val_auc = {val_auc:10.3e}, time per epoch: {(time.time()-start):10.3e}"
        )

        # Persist full training state so interrupted runs resume here.
        save_checkpoint(
            os.path.join(cfg.checkpoints.ckpt_path, cfg.checkpoints.ckpt_name),
            models=trainer.model,
            optimizer=trainer.optimizer,
            scheduler=trainer.scheduler,
            scaler=trainer.scaler,
            epoch=epoch,
        )
        start = time.time()
        trainer.scheduler.step()
    logger.info("Training completed!")
|
|
|
|
| """ |
| Perform training over all graphs in the dataset. |
| |
| Arguments: |
| cfg: Dictionary of parameters. |
| |
| """ |
| if __name__ == "__main__": |
| do_training() |