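"""
Training entry point for a MeshGraphNet-style graph classifier.

Builds the model from a Hydra config, trains it with a weighted BCE loss on
batched DGL graphs, reports validation loss and ROC AUC each epoch, and
checkpoints through PhysicsNeMo utilities. Supports optional AMP,
TorchScript compilation, and DistributedDataParallel execution.
"""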
import time, os
start = time.time()
import torch
from torch.nn.parallel import DistributedDataParallel
from dgl.dataloading import GraphDataLoader
from torch.amp import GradScaler
import numpy as np
import hydra
from omegaconf import DictConfig
from physicsnemo.launch.logging import (
PythonLogger,
RankZeroLoggingWrapper,
)
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint
from physicsnemo.distributed.manager import DistributedManager
import json
from tqdm import tqdm
import random
import models.MeshGraphNet as MeshGraphNet
from dataset.Dataset import get_dataset
import metrics
import utils
class MGNTrainer:
def __init__(self, logger, cfg, dist):
# set device
self.device = dist.device
logger.info(f"Using {self.device} device")
start = time.time()
self.trainloader, self.valloader, self.testloader = get_dataset(cfg, self.device)
print(f"total time loading dataset: {time.time() - start:.2f} seconds")
dtype_str = getattr(cfg.root_dataset, "dtype", "torch.float32")
if isinstance(dtype_str, str) and dtype_str.startswith("torch."):
self.dtype = getattr(torch, dtype_str.split(".")[-1], torch.float32)
else:
self.dtype = torch.float32
self.model = utils.build_from_module(cfg.architecture)
self.model = self.model.to(dtype=self.dtype, device=self.device)
        num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        logger.info(f"Number of trainable parameters: {num_params}")
        if cfg.performance.jit:
            # optionally compile with TorchScript; the model is already on
            # the target device and dtype
            self.model = torch.jit.script(self.model)
        # instantiate optimizer and scheduler (the loss is computed
        # directly via metrics.weighted_bce)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.scheduler.lr)
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer,
T_max=cfg.training.epochs,
eta_min=cfg.scheduler.lr * cfg.scheduler.lr_decay,
)
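        # Gradient scaler for mixed precision; backward() only routes the
        # loss through it when cfg.performance.amp is enabled.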
        self.scaler = GradScaler('cuda', enabled=cfg.performance.amp)
# load checkpoint
self.epoch_init = load_checkpoint(
os.path.join(cfg.checkpoints.ckpt_path, cfg.checkpoints.ckpt_name),
models=self.model,
optimizer=self.optimizer,
scheduler=self.scheduler,
scaler=self.scaler,
device=self.device,
)
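        # epoch_init is the epoch to resume from; PhysicsNeMo's
        # load_checkpoint returns 0 when no checkpoint is found.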
self.cfg = cfg
def backward(self, loss):
"""
Perform backward pass.
Arguments:
loss: loss value.
"""
# backward pass
if self.cfg.performance.amp:
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
else:
loss.backward()
self.optimizer.step()
def train(self, graph, metadata):
"""
Perform one training iteration over one graph. The training is performed
over multiple timesteps, where the number of timesteps is specified in
the 'stride' parameter.
Arguments:
graph: the desired graph.
Returns:
loss: loss value.
"""
        graph = graph.to(self.device, non_blocking=True)
        global_feats = metadata['globals'].to(self.device, non_blocking=True)
        label = metadata['label'].to(self.device, non_blocking=True)
        weight = metadata['weight'].to(self.device, non_blocking=True)
        self.optimizer.zero_grad()
        # Run the forward pass under autocast so the GradScaler used in
        # backward() actually pairs with mixed-precision compute.
        with torch.autocast(device_type=self.device.type, enabled=self.cfg.performance.amp):
            pred = self.model(graph.ndata["features"], graph.edata["features"], global_feats, graph, metadata)
            loss = metrics.weighted_bce(pred, label, weights=weight)
        self.backward(loss)
return loss.detach()
@torch.no_grad()
def eval(self):
"""
Evaluate the model on one batch.
Args:
graph (DGLGraph): The input graph.
label (Tensor): The target labels.
Returns:
loss (Tensor): The computed loss value (scalar).
"""
predictions = []
labels = []
weights = []
for graph, metadata in self.valloader:
graph = graph.to(self.device, non_blocking=True)
            global_feats = metadata['globals'].to(self.device, non_blocking=True)
            label = metadata['label'].to(self.device, non_blocking=True)
            weight = metadata['weight'].to(self.device, non_blocking=True)
            pred = self.model(graph.ndata["features"], graph.edata["features"], global_feats, graph, metadata)
predictions.append(pred)
labels.append(label)
weights.append(weight)
predictions = torch.cat(predictions, dim=0)
labels = torch.cat(labels, dim=0)
weights = torch.cat(weights, dim=0)
loss = metrics.weighted_bce(predictions, labels, weights=weights)
# Convert logits to probabilities
prob = torch.sigmoid(predictions)
# Flatten to 1D arrays
prob_flat = prob.detach().to(torch.float32).cpu().numpy().flatten()
labels_flat = labels.detach().to(torch.float32).cpu().numpy().flatten()
# Calculate AUC
try:
auc = metrics.roc_auc_score(labels_flat, prob_flat)
except ValueError:
auc = float('nan') # Not enough classes present for AUC
return loss, auc
@hydra.main(version_base=None, config_path="./configs/", config_name="tHjb_CP_0_vs_45")
def do_training(cfg: DictConfig):
"""
Perform training over all graphs in the dataset.
Arguments:
        cfg: Hydra DictConfig of run parameters.
"""
random.seed(cfg.random_seed)
np.random.seed(cfg.random_seed)
torch.manual_seed(cfg.random_seed)
# initialize distributed manager
DistributedManager.initialize()
dist = DistributedManager()
# initialize loggers
os.makedirs(cfg.checkpoints.ckpt_path, exist_ok=True)
logger = PythonLogger("main")
logger.file_logging(os.path.join(cfg.checkpoints.ckpt_path, "train.log"))
# initialize trainer
trainer = MGNTrainer(logger, cfg, dist)
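    # Wrap the model for multi-GPU training; issuing the DDP constructor on a
    # side CUDA stream and syncing afterwards mirrors the pattern used in the
    # PhysicsNeMo distributed examples.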
if dist.distributed:
ddps = torch.cuda.Stream()
with torch.cuda.stream(ddps):
trainer.model = DistributedDataParallel(
trainer.model,
                device_ids=[dist.local_rank],  # local rank of this process on its node
output_device=dist.device,
broadcast_buffers=dist.broadcast_buffers,
find_unused_parameters=dist.find_unused_parameters,
)
torch.cuda.current_stream().wait_stream(ddps)
# training loop
start = time.time()
logger.info("Training started...")
for epoch in range(trainer.epoch_init, cfg.training.epochs):
# Training
train_loss = []
        trainer.model.train()
        for graph, metadata in tqdm(trainer.trainloader, desc=f"epoch {epoch} training"):
loss = trainer.train(graph, metadata)
train_loss.append(loss.item())
val_loss, val_auc = trainer.eval()
train_loss = torch.tensor(train_loss).mean()
logger.info(
f"epoch: {epoch}, loss: {train_loss:10.3e}, val_loss: {val_loss:10.3e}, val_auc = {val_auc:10.3e}, time per epoch: {(time.time()-start):10.3e}"
)
# save checkpoint
save_checkpoint(
os.path.join(cfg.checkpoints.ckpt_path, cfg.checkpoints.ckpt_name),
models=trainer.model,
optimizer=trainer.optimizer,
scheduler=trainer.scheduler,
scaler=trainer.scaler,
epoch=epoch,
)
start = time.time()
trainer.scheduler.step()
logger.info("Training completed!")
"""
Perform training over all graphs in the dataset.
Arguments:
cfg: Dictionary of parameters.
"""
if __name__ == "__main__":
do_training()