from __future__ import annotations

import argparse
import json
import logging

import torch
import torch.distributed as dist
from monai.utils import RankFilter


def setup_logging(logger_name: str = "") -> logging.Logger:
    """
    Set up the logging configuration.

    Args:
        logger_name (str): logger name.

    Returns:
        logging.Logger: Configured logger.
    """
    logger = logging.getLogger(logger_name)
    if dist.is_initialized():
        # In distributed runs, RankFilter suppresses log records from non-zero ranks.
        logger.addFilter(RankFilter())
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return logger
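
# Example (illustrative only; the logger name "maisi" is an arbitrary choice, not
# one this module requires):
#
#     logger = setup_logging("maisi")
#     logger.info("starting")  # e.g. "[2025-01-01 12:00:00.000][ INFO](maisi) - starting"
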
def load_config(env_config_path: str, model_config_path: str, model_def_path: str) -> argparse.Namespace:
    """
    Load configuration from JSON files.

    Args:
        env_config_path (str): Path to the environment configuration file.
        model_config_path (str): Path to the model configuration file.
        model_def_path (str): Path to the model definition file.

    Returns:
        argparse.Namespace: Loaded configuration.
    """
    args = argparse.Namespace()

    # Merge every key/value pair from the three JSON files into a single namespace;
    # later files overwrite earlier ones when keys collide.
    with open(env_config_path, "r") as f:
        env_config = json.load(f)
    for k, v in env_config.items():
        setattr(args, k, v)

    with open(model_config_path, "r") as f:
        model_config = json.load(f)
    for k, v in model_config.items():
        setattr(args, k, v)

    with open(model_def_path, "r") as f:
        model_def = json.load(f)
    for k, v in model_def.items():
        setattr(args, k, v)

    return args
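
# Example (a sketch only; the file names and keys below are hypothetical, not ones
# this module mandates):
#
#     # env_config.json:    {"model_dir": "./models", "output_dir": "./output"}
#     # model_config.json:  {"batch_size": 1, "lr": 1e-4}
#     # model_def.json:     {"spatial_dims": 3}
#     args = load_config("env_config.json", "model_config.json", "model_def.json")
#     print(args.batch_size)  # 1
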
def initialize_distributed(num_gpus: int) -> tuple:
    """
    Initialize distributed training.

    Args:
        num_gpus (int): Number of GPUs to use; distributed mode is enabled when > 1.

    Returns:
        tuple: local_rank, world_size, and device.
    """
    if torch.cuda.is_available() and num_gpus > 1:
        # Expects torchrun-style environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT).
        dist.init_process_group(backend="nccl", init_method="env://")
        local_rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        local_rank = 0
        world_size = 1
    if torch.cuda.is_available():
        device = torch.device("cuda", local_rank)
        torch.cuda.set_device(device)
    else:
        # Fall back to CPU so this call does not fail on machines without CUDA.
        device = torch.device("cpu")
    return local_rank, world_size, device
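
# Example (a sketch; "train.py" is a hypothetical entry point, and the env:// init
# method assumes a torchrun-style launch that sets RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT):
#
#     # launched as: torchrun --nproc_per_node=2 train.py
#     local_rank, world_size, device = initialize_distributed(num_gpus=2)
#     model = model.to(device)  # "model" stands in for whatever module is being trained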