File size: 3,383 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Global seed management utilities for reproducible experiments.

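This module provides a centralized way to seed the random number generators
used in the platform: Python's ``random``, NumPy, and PyTorch. scikit-learn is
covered indirectly, because estimators without an explicit ``random_state``
draw from NumPy's global RNG.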
"""

import os
import time
import random
import numpy as np
from typing import Optional

# Module-level record of the seed most recently applied by set_global_seed();
# it stays None until that function has run at least once.
_GLOBAL_SEED: Optional[int] = None


def get_global_seed() -> Optional[int]:
    """
    Return the seed most recently recorded by ``set_global_seed``.

    Returns:
        The stored seed as an int, or None when no seed has been set yet.
    """
    current_seed = _GLOBAL_SEED
    return current_seed


def set_cublas_workspace_config():
    """Ensure CUBLAS_WORKSPACE_CONFIG holds a deterministic-compatible value.

    Call this before torch is imported. NVIDIA documents two accepted values:
      - ":4096:8" (recommended; used as the default here)
      - ":16:8"   (minimal workspace)

    A value already present in the environment (set by the user or the
    launcher) is left untouched.
    """
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")


def seed_worker(worker_id: int):
    """Re-seed NumPy and ``random`` from torch's seed in the calling process.

    Intended usage: ``torch.utils.data.DataLoader(worker_init_fn=seed_worker)``.
    The seed is ``torch.initial_seed()`` reduced modulo 2**32, because NumPy's
    legacy seeding only accepts 32-bit values. ``worker_id`` is required by
    the worker_init_fn signature but is not used here.
    """
    import torch

    derived_seed = torch.initial_seed() % (1 << 32)
    # Same order as before: NumPy first, then Python's random module.
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)


def dataloader_generator(seed: Optional[int]):
    """
    Build a seeded ``torch.Generator`` for a DataLoader.

    Intended usage: ``torch.utils.data.DataLoader(generator=dataloader_generator(seed))``.
    The DataLoader draws its shuffling order and the workers' base seed from
    this generator, so both become reproducible. Pair it with
    ``worker_init_fn=seed_worker`` to also re-seed NumPy/random in each worker.

    Args:
        seed: Seed for the generator. If None, the seed already recorded by
            ``set_global_seed`` is reused. Only when no global seed exists yet
            is a new one created (and applied globally) via ``set_global_seed()``.

    Returns:
        A ``torch.Generator`` seeded with the resolved seed.
    """
    import torch

    if seed is None:
        # Reuse an established global seed instead of calling
        # set_global_seed() unconditionally: that would draw a fresh
        # time-based seed and re-seed random/NumPy/torch, discarding any
        # seed the caller had already set.
        seed = get_global_seed()
        if seed is None:
            seed = set_global_seed()

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator


def set_global_seed(seed: Optional[int] = None) -> int:
    """
    Set the global random seed for all random number generators.

    This function seeds:
    - Python's random module
    - NumPy's global RNG
    - PyTorch (CPU, plus all CUDA devices when CUDA is available)

    The seed is recorded for ``get_global_seed()`` only after every generator
    has been seeded successfully.

    Args:
        seed: The seed value to use. If None, a seed in [0, 2**31) is derived
            from the current time (microsecond resolution).

    Returns:
        The seed value that was set.

    Raises:
        ValueError: If ``seed`` is outside [0, 2**32 - 1], the range accepted
            by NumPy's legacy ``np.random.seed``.
    """
    global _GLOBAL_SEED

    if seed is None:
        # Generate seed from current time if not provided
        seed = int(time.time() * 1000000) % (2**31)
    elif not 0 <= seed <= 2**32 - 1:
        # Validate up front. np.random.seed would reject this anyway, but only
        # after random.seed had already run, leaving the RNGs partially
        # re-seeded.
        raise ValueError(f"seed must be in [0, 2**32 - 1], got {seed!r}")

    random.seed(seed)
    np.random.seed(seed)

    # Import torch lazily to avoid initializing CUDA before env is set elsewhere
    import torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all also covers the current device, so a separate
        # torch.cuda.manual_seed call would be redundant.
        torch.cuda.manual_seed_all(seed)

    # Record the seed last, so a failure above never leaves get_global_seed()
    # reporting a seed that was not actually applied.
    _GLOBAL_SEED = seed
    return seed


def set_determinism():
    """Configure PyTorch for deterministic (reproducible) execution.

    Sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``. Then, if the installed torch
    provides ``torch.use_deterministic_algorithms``, enables it in strict mode
    (``warn_only=False``) so that non-deterministic ops raise instead of
    warning. Torch builds that predate the ``warn_only`` keyword get the same
    strict mode without it. Any other failure is reported via print and does
    not raise.

    Note: this can significantly slow down operations. Only use it if you
    need runs to be 100% reproducible.
    """
    # set_cublas_workspace_config() must happen BEFORE importing torch
    #set_cublas_workspace_config()
    # NOTE(review): with the call above disabled, CUBLAS_WORKSPACE_CONFIG must
    # come from the environment/launcher. PyTorch's reproducibility notes say
    # deterministic mode on CUDA needs it, otherwise some cuBLAS ops raise.

    # Import torch only after the env var has been set
    import torch

    # Set deterministic behavior for reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if not hasattr(torch, 'use_deterministic_algorithms'):
        return

    try:
        try:
            torch.use_deterministic_algorithms(True, warn_only=False)
        except TypeError:
            # Older torch releases reject the ``warn_only`` keyword. Strict
            # mode is their only behavior, so retry without it rather than
            # leaving deterministic algorithms disabled.
            torch.use_deterministic_algorithms(True)
    except Exception as e:
        print(f'torch.use_deterministic_algorithms is not available: {e}')
        # print torch version
        print(f'torch version: {torch.__version__}')
        print('Make sure you are using the correct version of torch')