| | import os, sys |
| | import re |
| | import torch |
| | import argparse |
| | import yaml |
| | import pandas as pd |
| | import numpy as np |
| | from glob import glob |
| | from queue import Queue |
| | from loguru import logger |
| | from threading import Thread |
| | from torch_geometric.data import Data, HeteroData |
| | import torch.distributed as dist |
| | import random |
| | import subprocess |
| | import time |
| | from torch.utils.tensorboard import SummaryWriter |
| | from datetime import datetime |
| |
|
| |
|
| | |
| |
|
class AverageMeter(object):
    """Tracks a statistic over scalar updates.

    Two modes, selected at construction:
      * length > 0: sliding window over the most recent `length` values
        (unit-weight updates only).
      * length == 0: weighted running sum over all updates.

    After every update, `val` holds the latest value and `avg` the mean.
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Discard all accumulated state."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        """Fold one observation into the meter.

        Args:
            val: the new scalar value.
            num: weight of the observation (running-sum mode only).
        """
        if self.length > 0:
            # Windowed mode supports only unit-weight updates.
            assert num == 1
            self.history.append(val)
            # Keep at most `length` entries by trimming from the front.
            if len(self.history) > self.length:
                self.history = self.history[-self.length:]
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val * num
            self.count += num
            self.avg = self.sum / self.count
| |
|
| |
|
class AVGMeter():
    """Minimal running-average accumulator: sum everything, divide on demand."""

    def __init__(self):
        # Start from a clean slate; reset() owns the initial state.
        self.reset()

    def update(self, v_new):
        """Add one observation."""
        self.value = self.value + v_new
        self.cnt = self.cnt + 1

    def agg(self):
        """Mean of all observations since the last reset."""
        return self.value / self.cnt

    def reset(self):
        """Zero the accumulator and the observation counter."""
        self.value = 0
        self.cnt = 0
| |
|
| |
|
class Reporter():
    """Thin wrapper around TensorBoard's SummaryWriter for epoch metrics."""

    def __init__(self, cfg, log_dir) -> None:
        print("=" * 20, cfg['log_path'])
        self.cfg = cfg
        self.writer = SummaryWriter(log_dir)

    def record(self, value_dict, epoch):
        """Write each entry as a scalar; AVGMeter entries are aggregated first."""
        for key, value in value_dict.items():
            scalar = value.agg() if isinstance(value, AVGMeter) else value
            self.writer.add_scalar(key, scalar, epoch)

    def close(self):
        """Flush and release the underlying SummaryWriter."""
        self.writer.close()
| |
|
| |
|
class Timer:
    """Context manager that times its body and projects remaining runtime.

    On exit, `elapsed_time` holds the body's wall-clock seconds and `eta`
    the projected hours for `rest_epochs` more iterations of the same
    duration, rounded to two decimals.
    """

    def __init__(self, rest_epochs):
        self.rest_epochs = rest_epochs
        self.elapsed_time = None
        self.eta = None

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        finished_at = time.time()
        self.elapsed_time = finished_at - self.start_time
        # Linear extrapolation: seconds/epoch * epochs left, in hours.
        self.eta = round(self.rest_epochs * self.elapsed_time / 3600, 2)
| |
|
| |
|
| |
|
| | |
def get_argparse():
    """Parse training command-line arguments.

    Returns:
        argparse.Namespace with: config (YAML path), distributed (flag),
        local_rank (int), seed (int), ngpus (int).
    """
    # NOTE: the original defined an unused `str2bool` lambda here; removed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./configs/default.yaml')
    parser.add_argument('--distributed', default=False, action='store_true')
    parser.add_argument('--local-rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument("--seed", type=int, default=2024)
    parser.add_argument("--ngpus", type=int, default=1)
    args = parser.parse_args()
    return args
| |
|
def count_parameters(model):
    """Return the number of trainable parameters of `model`, in millions."""
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable) / 1_000_000
| |
|
def model_info(model, verbose=False, img_size=640):
    """Log a one-line model summary: module count, total and trainable
    parameters, plus an estimated GFLOPs figure when the optional `thop`
    package is installed.

    Args:
        model: a torch.nn.Module.
        verbose: if True, also print a per-parameter table to stdout.
        img_size: square input resolution used for the FLOPs estimate.
    """
    n_p = sum(x.numel() for x in model.parameters())                       # total params
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)    # trainable params
    if verbose:
        print('%5s %40s %9s %12s %20s %10s %10s' %
              ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:
        # `deepcopy` was referenced but never imported anywhere in this file,
        # so the FLOPs path always raised NameError (silently swallowed by the
        # broad except) and FLOPs were never reported. Import it locally.
        from copy import deepcopy
        from thop import profile
        # thop returns MACs; x2 converts to FLOPs, /1e9 to GFLOPs.
        flops = profile(deepcopy(model),
                        inputs=(torch.zeros(1, 3, img_size, img_size),),
                        verbose=False)[0] / 1E9 * 2
        fs = ', %.9f GFLOPS' % (flops)
    except Exception:
        # thop is optional and profiling is best-effort: fall back to a
        # summary without the FLOPs figure.
        fs = ''

    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
| |
|
def get_cfg():
    """Build the run configuration.

    Loads the YAML file named by --config, overlays every non-None CLI
    argument on top of it, appends the config stem to the log path, and
    extracts the heterogeneous-graph metadata.

    Returns:
        (cfg, metadata) where metadata = (node_types, edge_types) and each
        edge type is a tuple, as torch_geometric expects.
    """
    args = get_argparse()

    with open(args.config, 'r') as fh:
        cfg = yaml.safe_load(fh)

    # CLI values take precedence over the YAML file; None means "not given".
    for key, value in vars(args).items():
        if value is not None:
            cfg[key] = value

    # e.g. <log_path>/<config-stem>; [:-5] strips the '.yaml' suffix.
    config_stem = os.path.basename(args.config)[:-5]
    cfg['log_path'] = os.path.join(cfg['log_path'], config_stem)

    node_types = cfg['data']['meta']['node']
    edge_types = [tuple(edge) for edge in cfg['data']['meta']['edge']]
    return cfg, (node_types, edge_types)
| |
|
| |
|
def init_seeds(seed=0):
    """Seed the Python, NumPy, and (CPU) torch RNGs with the same value."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
| |
|
| |
|
def set_random_seed(seed, deterministic=False):
    """Seed every RNG in use; optionally force deterministic cuDNN.

    Args:
        seed: value for python/numpy/torch (CPU and, if present, all GPUs).
        deterministic: if True, disable cuDNN autotuning and request
            deterministic kernels (reproducible but potentially slower).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # cuDNN stays enabled either way; only benchmarking/determinism differ.
    torch.backends.cudnn.enabled = True
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True
| |
|
| |
|
def get_world_size():
    """Number of distributed processes; 1 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
| |
|
| |
|
def get_rank():
    """Rank of this process in the distributed group; 0 when not distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
| |
|
| |
|
def is_main_process():
    """True only on rank 0 (always True in non-distributed runs)."""
    return not get_rank()
| |
|
| | |
| |
|
| |
|
# Module-level registry of log keys. Appears unused within this chunk —
# presumably consumed by code elsewhere in the project; verify before removing.
logs = set()
| |
|
| |
|
def time_str(fmt=None):
    """Current local time as a string; default format '%Y-%m-%d_%H:%M:%S'."""
    # Explicit None check: an empty fmt string must stay an empty result.
    chosen = '%Y-%m-%d_%H:%M:%S' if fmt is None else fmt
    return datetime.today().strftime(chosen)
| |
|
| |
|
def setup_default_logging(save_path, flag_multigpus=False, l_level='INFO'):
    """Attach a timestamped loguru file sink under `save_path`.

    In multi-GPU runs only rank 0 installs a sink; other ranks return None
    so no duplicate log files are created.

    Returns:
        The timestamp string used in the log filename, or None on
        non-zero ranks.
    """
    if flag_multigpus and dist.get_rank() != 0:
        return

    stamp = time_str(fmt='%Y_%m_%d_%H_%M_%S')
    logger.add(
        os.path.join(save_path, f'{stamp}.log'),
        level=l_level,
        format='{level}|{time:YYYY-MM-DD HH:mm:ss}: {message}',
        enqueue=True,   # queue writes so logging is safe across processes
        encoding='utf-8',
    )
    return stamp
| |
|
| |
|
| |
|
def world_info_from_env():
    """Read distributed-launch info from the environment.

    Checks the variable names used by torchrun, Intel MPI, SLURM, and
    OpenMPI (in that priority order) and falls back to a single-process
    default of (0, 0, 1).

    Returns:
        (local_rank, global_rank, world_size) as ints.
    """
    def _first_int(names, default):
        # First matching env var wins; order encodes launcher priority.
        for name in names:
            if name in os.environ:
                return int(os.environ[name])
        return default

    local_rank = _first_int(
        ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID',
         'OMPI_COMM_WORLD_LOCAL_RANK'), 0)
    global_rank = _first_int(
        ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'), 0)
    world_size = _first_int(
        ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'), 1)

    return local_rank, global_rank, world_size
| |
|
| |
|
def setup_distributed(backend="nccl", port=None):
    """Initialize torch.distributed for SLURM or torchrun-style launches.

    Under SLURM (unless ZHENSALLOC is set), derives rank/world size from
    the environment, resolves the master address from the node list, and
    populates MASTER_ADDR/MASTER_PORT/WORLD_SIZE/LOCAL_RANK/RANK before
    calling init_process_group. Otherwise assumes a torchrun-style launch
    where RANK and WORLD_SIZE are already set.

    Adapted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
    (originally licensed MIT, Copyright (c) 2020 Wei Li).

    Args:
        backend: process-group backend passed to init_process_group.
        port: master port override; defaults to 10685 if MASTER_PORT unset.

    Returns:
        (rank, world_size) of this process.
    """
    num_gpus = torch.cuda.device_count()

    # SLURM path; ZHENSALLOC acts as an opt-out escape hatch.
    if "SLURM_JOB_ID" in os.environ and "ZHENSALLOC" not in os.environ:
        _, rank, world_size = world_info_from_env()
        node_list = os.environ["SLURM_NODELIST"]
        # First hostname in the allocation becomes the rendezvous master.
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")

        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "10685"
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank % num_gpus)
        os.environ["RANK"] = str(rank)
    else:
        # torchrun-style launcher already populated RANK / WORLD_SIZE.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    # Map each process to a GPU on its node (assumes one process per GPU).
    torch.cuda.set_device(rank % num_gpus)

    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )

    return rank, world_size
| |
|
| |
|
| |
|
| |
|
| | |
def setup_default_logging_wt_dir(save_path, flag_multigpus=False, l_level='INFO'):
    """Attach a loguru file sink inside a fresh timestamped subdirectory
    of `save_path` (the subdirectory and the log file share the timestamp).

    In multi-GPU runs only rank 0 installs a sink; other ranks return None.

    Returns:
        The timestamp string (also the subdirectory name), or None on
        non-zero ranks.
    """
    if flag_multigpus and dist.get_rank() != 0:
        return

    stamp = time_str(fmt='%Y_%m_%d_%H_%M_%S')
    run_dir = os.path.join(save_path, stamp)
    os.makedirs(run_dir, exist_ok=True)
    logger.add(
        os.path.join(run_dir, f'{stamp}.log'),
        level=l_level,
        format='{level}|{time:YYYY-MM-DD HH:mm:ss}: {message}',
        enqueue=True,   # queue writes so logging is safe across processes
        encoding='utf-8',
    )
    return stamp
| |
|
| |
|
| | |
| |
|
def seed_worker(worker_id):
    """DataLoader worker_init_fn: give each worker a distinct, reproducible
    seed derived from the parent process's NumPy RNG state plus its id."""
    base_seed = np.random.get_state()[1][0]
    worker_seed = base_seed + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)
| |
|