"""Reproducibility utilities for deterministic training.""" import random import numpy as np import torch import os from typing import Optional import logging logger = logging.getLogger(__name__) def set_random_seeds(seed: int) -> None: """ Set random seeds for all libraries to ensure reproducibility. Args: seed: Random seed value """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) logger.info(f"Random seeds set to {seed}") def set_deterministic_mode(enabled: bool = True) -> None: """ Enable or disable deterministic mode for PyTorch operations. Note: Deterministic mode may reduce performance but ensures reproducibility. Args: enabled: Whether to enable deterministic mode """ if enabled: torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # For PyTorch >= 1.8 if hasattr(torch, 'use_deterministic_algorithms'): torch.use_deterministic_algorithms(True) logger.info("Deterministic mode enabled") else: torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = True if hasattr(torch, 'use_deterministic_algorithms'): torch.use_deterministic_algorithms(False) logger.info("Deterministic mode disabled") def get_environment_info() -> dict: """ Get information about the execution environment. Returns: Dictionary with environment information """ import sys import platform info = { 'python_version': sys.version, 'platform': platform.platform(), 'pytorch_version': torch.__version__, 'cuda_available': torch.cuda.is_available(), } if torch.cuda.is_available(): info['cuda_version'] = torch.version.cuda info['cudnn_version'] = torch.backends.cudnn.version() info['gpu_count'] = torch.cuda.device_count() info['gpu_names'] = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] return info def log_environment_info() -> None: """Log environment information.""" info = get_environment_info() logger.info("=" * 80) logger.info("Environment Information:") logger.info("=" * 80) for key, value in info.items(): logger.info(f"{key}: {value}") logger.info("=" * 80) def setup_reproducibility(seed: int, deterministic: bool = False) -> None: """ Set up reproducibility by setting seeds and optionally enabling deterministic mode. Args: seed: Random seed value deterministic: Whether to enable deterministic mode """ set_random_seeds(seed) if deterministic: set_deterministic_mode(True) log_environment_info()