# Copyright (c) 2025 Hanwen Jiang, Xuweiyi Chen.
# Adapted for WildRayZer from the RayZer project.

import argparse
import copy
import datetime
import os
import random
import re
import shutil
import time
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import wandb
import yaml
from easydict import EasyDict as edict
from omegaconf import OmegaConf


#################Init Config Begins#################
def process_overrides(overrides):
    """
    Normalize command-line override tokens so that spacing around "=" is removed.

    The shell splits ``param = value`` into three separate tokens; this joins
    all tokens, collapses whitespace around every "=", and splits the result
    back into a list.

    Args:
        overrides: Iterable of raw override strings as captured by argparse.

    Returns:
        list: Tokens where each ``key = value`` group has become ``key=value``;
            tokens without "=" are passed through unchanged.
    """
    # Join everything so that "key", "=", "value" can be matched as one pattern.
    combined = " ".join(overrides)

    # Collapse any whitespace surrounding "=": 'param = value' -> 'param=value'.
    fixed_string = re.sub(r"(\S+)\s*=\s*(\S+)", r"\1=\2", combined)

    # Split back into tokens: either a full key=value pair or a bare word.
    processed = re.findall(r"[^\s=]+=\S+|\S+", fixed_string)

    return processed


def init_config():
    """
    Build the run configuration from a YAML file plus command-line overrides.

    Parses ``--config/-c`` (required) and any trailing ``key=value`` overrides
    from ``sys.argv``, merges the overrides on top of the file contents with
    OmegaConf, resolves interpolations, and wraps the result in an EasyDict.

    Returns:
        edict: The merged configuration with attribute-style access.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", required=True)
    parser.add_argument("overrides", nargs="*")  # Capture all "key=value" args
    args = parser.parse_args()

    # Load base config
    config = OmegaConf.load(args.config)

    # Parse CLI overrides using OmegaConf's native CLI parser
    processed_overrides = process_overrides(args.overrides)
    cli_overrides = OmegaConf.from_cli(processed_overrides)

    # Merge configs (CLI values take precedence over the file)
    config = OmegaConf.merge(config, cli_overrides)

    # Convert to plain containers first, then to EasyDict
    config = OmegaConf.to_container(config, resolve=True)
    config = edict(config)

    return config


#################Init Config End#################


def init_distributed(seed=42):
    """
    Initialize distributed training environment and set random seeds for reproducibility.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment,
    binds this process to ``cuda:<LOCAL_RANK>`` and creates the NCCL process
    group with a one-hour timeout.

    Args:
        seed (int): Base random seed for PyTorch, NumPy, and Python's random module.
            Each process uses ``seed + global_rank``. Default is 42.

    Returns:
        edict: Dictionary with attribute access containing:
            - local_rank: GPU rank within the current node
            - global_rank: Global rank of the process
            - world_size: Total number of processes
            - device: The CUDA device assigned to this process
            - is_main_process: Flag to identify the main process
            - seed: The random seed used for this process

    Raises:
        KeyError: If any of the required environment variables is not set.
    """
    global_rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    dist.init_process_group(
        backend="nccl", timeout=datetime.timedelta(seconds=3600), device_id=device
    )

    # Set random seeds
    # Each process gets a different seed derived from the base seed
    process_seed = seed + global_rank
    torch.manual_seed(process_seed)
    torch.cuda.manual_seed(process_seed)
    torch.cuda.manual_seed_all(process_seed)
    np.random.seed(process_seed)
    random.seed(process_seed)

    # Optional: For better performance
    torch.backends.cudnn.benchmark = True

    return edict(
        {
            "local_rank": local_rank,
            "global_rank": global_rank,
            "world_size": world_size,
            "device": device,
            "is_main_process": global_rank == 0,
            "seed": process_seed,
        }
    )


def local_backup_src_code(
    src_dir,
    dst_dir,
    max_size_MB=4.0,
    extension_to_backup=(".py", ".yaml", ".sh", ".bash", ".json"),
    exclude_dirs=("wandb", ".git", "checkpoints", "experiments"),
    verbose=True,
):
    """
    Backup source code files with size limit check.

    Walks ``src_dir``, collects every file whose extension is in
    ``extension_to_backup`` (skipping the excluded directories), verifies the
    total size against ``max_size_MB``, and then copies the files to
    ``dst_dir`` preserving the relative directory layout. Individual copy
    failures are reported (when ``verbose``) and skipped.

    Args:
        src_dir: Source directory to backup
        dst_dir: Destination directory for backups
        max_size_MB: Maximum total size allowed for backup in MB
        extension_to_backup: File extensions to include in backup
        exclude_dirs: Directories to exclude from backup, given relative to
            ``src_dir``
        verbose: Whether to print progress information

    Returns:
        tuple: (num_files_backed_up, total_size_in_bytes)

    Raises:
        FileNotFoundError: If ``src_dir`` does not exist
        ValueError: If total size exceeds max_size_MB
    """
    start_time = time.time()
    src_path = Path(src_dir).resolve()
    dst_path = Path(dst_dir).resolve()

    # Convert to set for faster lookup
    extension_set = set(extension_to_backup)
    ignore_paths = {(src_path / d).resolve() for d in exclude_dirs}
    max_bytes = int(max_size_MB * 1024 * 1024)

    if not src_path.exists():
        raise FileNotFoundError(f"Source directory does not exist: {src_path}")

    files = []  # (src_file, dst_file, size_in_bytes)
    total_size = 0

    for dirpath, dirnames, filenames in os.walk(src_path):
        current_path = Path(dirpath).resolve()

        # Skip excluded directories; clearing dirnames stops os.walk from
        # descending any further below this point.
        if current_path in ignore_paths or any(
            parent in ignore_paths for parent in current_path.parents
        ):
            dirnames.clear()
            continue

        # The destination sub-directory is the same for every file in this
        # directory, so compute it once.
        dst_subdir = dst_path / current_path.relative_to(src_path)

        # Filter files by extension
        for filename in filenames:
            file_ext = os.path.splitext(filename)[1]
            if file_ext not in extension_set:
                continue

            src_file = current_path / filename
            dst_file = dst_subdir / filename

            try:
                file_size = src_file.stat().st_size
                total_size += file_size
                files.append((src_file, dst_file, file_size))
            except (FileNotFoundError, PermissionError) as e:
                if verbose:
                    print(f"Warning: Could not access {src_file}: {e}")

    if total_size > max_bytes:
        if verbose:
            print(f"Size limit exceeded: {total_size / (1024*1024):.2f} MB > {max_size_MB} MB")
            print("Largest files:")
            for src_file, _, size in sorted(files, key=lambda x: x[2], reverse=True)[:5]:
                print(f"{src_file}: {size / 1024:.1f} KB")

        raise ValueError(
            f"Size limit exceeded: {total_size / (1024*1024):.2f} MB > {max_size_MB} MB"
        )

    if verbose:
        print(f"Backing up {len(files)} files ({total_size / (1024*1024):.2f} MB)")

    dst_path.mkdir(parents=True, exist_ok=True)

    # Copy files (best effort: a failed copy is reported and skipped)
    successful_copies = 0
    for src_file, dst_file, _ in files:
        try:
            dst_file.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src_file, dst_file)
            successful_copies += 1
        except OSError as e:
            if verbose:
                print(f"Error copying {src_file} to {dst_file}: {e}")

    elapsed_time = time.time() - start_time
    if verbose:
        print(
            f"Backup completed: {successful_copies}/{len(files)} files copied in {elapsed_time:.2f} seconds"
        )

    return successful_copies, total_size


def _to_plain_container(obj):
    """Recursively convert dict subclasses (e.g. EasyDict) into built-in dicts."""
    if isinstance(obj, dict):
        return {key: _to_plain_container(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [_to_plain_container(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(_to_plain_container(item) for item in obj)
    return obj


def init_wandb_and_backup(config):
    """
    Log in to Weights & Biases, start a run, and save the config to disk.

    Reads the WandB API key from the YAML file at
    ``config.training.api_key_path``, exports it as ``WANDB_API_KEY``, starts a
    run with ``config.training.wandb_project`` / ``wandb_exp_name``, writes the
    configuration to ``<checkpoint_dir>/config.yaml`` and registers the
    ``<checkpoint_dir>/src/<package>`` directory with ``wandb.run.log_code``.

    Args:
        config: Run configuration (EasyDict) providing ``training.api_key_path``,
            ``training.wandb_project``, ``training.wandb_exp_name`` and
            ``training.checkpoint_dir``.

    Raises:
        AssertionError: If the API key file does not exist or has no ``wandb`` entry.
    """
    # API key validation
    assert os.path.exists(
        config.training.api_key_path
    ), f"API key file does not exist: {config.training.api_key_path}"

    # Use a context manager so the key file handle is closed deterministically.
    with open(config.training.api_key_path, "r") as key_file:
        api_keys = edict(yaml.safe_load(key_file))

    # .get() so that a missing key reaches the assertion message below instead
    # of failing earlier with an AttributeError.
    assert api_keys.get("wandb") is not None, "Wandb API key not found in api key file"

    # WandB setup and login
    os.environ["WANDB_API_KEY"] = api_keys.wandb

    # WandB initialization
    config_copy = copy.deepcopy(config)
    wandb.init(
        project=config.training.wandb_project,
        name=config.training.wandb_exp_name,
        config=config_copy,
    )

    # Source code backup
    cur_dir = os.path.dirname(os.path.realpath(__file__))
    trgt_dir = os.path.join(config.training.checkpoint_dir, "src", os.path.basename(cur_dir))
    os.makedirs(trgt_dir, exist_ok=True)

    extension_to_backup = (".py", ".yaml", ".sh", ".bash", ".json")
    exclude_dirs = ("wandb", ".git", "checkpoints", "experiments")
    # NOTE(review): the local backup below is disabled, so log_code further
    # down only sees whatever already exists in trgt_dir - confirm intent.
    # local_backup_src_code(cur_dir, trgt_dir, extension_to_backup=extension_to_backup, exclude_dirs=exclude_dirs)

    # Save config file. Nested EasyDict sections are converted to built-in
    # containers first; otherwise PyYAML serializes them with
    # python-object tags that safe YAML loaders cannot read back.
    config_save_path = os.path.join(config.training.checkpoint_dir, "config.yaml")
    with open(config_save_path, "w") as f:
        yaml.dump(_to_plain_container(config), f)

    wandb.run.log_code(
        trgt_dir,
        include_fn=lambda path: path.endswith(extension_to_backup),
    )