File size: 3,419 Bytes
d7ecc62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae46efa
 
d7ecc62
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""GPU auto-detection and optimal training settings.

Auto-detects NVIDIA vs AMD GPUs and configures torch.compile and SDPA
backend for best performance. ROCm's flash attention backward has stride
mismatches with torch.compile, so we fall back to the MATH SDPA backend
on AMD GPUs (still ~30% faster than no compile at all).
"""


def is_rocm() -> bool:
    """Return True when the active CUDA device is an AMD GPU running ROCm.

    Uses ``torch.version.hip``, which is a version string on ROCm builds of
    torch and ``None`` on CUDA builds. This is more reliable than the old
    device-name substring checks ("radeon"/"rx "/"mi "), which miss parts
    whose marketing name lacks those tokens (e.g. some Instinct SKUs) and
    can false-positive on unrelated names.
    """
    import torch
    if not torch.cuda.is_available():
        return False
    # A ROCm build cannot drive an NVIDIA GPU and vice versa, so the build
    # flavor uniquely identifies the platform of the visible device.
    return torch.version.hip is not None


def configure_gpu(
    device: str = "cuda",
    *,
    no_compile: bool = False,
    no_amp: bool = False,
    sdpa_math: bool = False,
) -> dict:
    """Auto-detect GPU and return optimal training settings.

    Returns a dict with:
        use_compile: bool — whether to torch.compile the forward pass
        use_amp: bool — whether to use automatic mixed precision
        sdpa_backend: SDPBackend | None — SDPA backend override (None = default)

    CLI flags (no_compile, no_amp, sdpa_math) act as overrides. When not
    set, the function picks the fastest settings for the detected GPU:

        NVIDIA: compile + AMP + flash attention (default SDPA)
        AMD:    compile + AMP + MATH SDPA (avoids flash attn backward bug)
        CPU:    no compile, no AMP
    """
    import torch
    from torch.nn.attention import SDPBackend

    is_cuda = device.startswith("cuda") and torch.cuda.is_available()
    rocm = is_cuda and is_rocm()

    # Defaults: compile and AMP on for CUDA
    use_compile = is_cuda and not no_compile
    use_amp = is_cuda and not no_amp

    # SDPA backend: use MATH on ROCm (flash attn backward is broken with compile)
    if sdpa_math:
        sdpa_backend = SDPBackend.MATH
    elif rocm and use_compile:
        sdpa_backend = SDPBackend.MATH
    else:
        sdpa_backend = None

    # Log what we're doing
    if is_cuda:
        gpu_name = torch.cuda.get_device_name(0)
        platform = "ROCm" if rocm else "CUDA"
        print(f"GPU: {gpu_name} ({platform})")
    else:
        print("GPU: none (CPU mode)")

    if use_compile:
        print(f"  torch.compile: enabled (inductor)")
    else:
        print(f"  torch.compile: disabled")

    if use_amp:
        print(f"  AMP: enabled (fp16)")
    else:
        print(f"  AMP: disabled")

    if sdpa_backend is not None:
        print(f"  SDPA backend: {sdpa_backend.name}")
    else:
        print(f"  SDPA backend: default")

    return {
        "use_compile": use_compile,
        "use_amp": use_amp,
        "sdpa_backend": sdpa_backend,
    }


def apply_gpu_config(config: dict, model_module, forward_fn):
    """Apply a configure_gpu() result to the model.

    Installs the SDPA backend override on *model_module* (its
    ``SDPA_BACKEND`` attribute) and wraps *forward_fn* with torch.compile
    when requested.

    Args:
        config: settings dict produced by configure_gpu()
        model_module: the pawn.model module (to set SDPA_BACKEND)
        forward_fn: the forward function to compile (e.g. model.forward_hidden)

    Returns:
        The (possibly compiled) forward function.
    """
    backend = config["sdpa_backend"]
    # Order matters: the backend must be installed before compiling, because
    # compiled code captures SDPA_BACKEND at trace time — a later assignment
    # would be silently ignored.
    if backend is not None:
        model_module.SDPA_BACKEND = backend

    if not config["use_compile"]:
        return forward_fn

    import torch
    return torch.compile(forward_fn)