"""GPU auto-detection and optimal training settings.
Auto-detects NVIDIA vs AMD GPUs and configures torch.compile and SDPA
backend for best performance. ROCm's flash attention backward has stride
mismatches with torch.compile, so we fall back to the MATH SDPA backend
on AMD GPUs (still ~30% faster than no compile at all).
"""
def is_rocm() -> bool:
    """Return True when the active CUDA device is an AMD/ROCm GPU."""
    import torch

    if not torch.cuda.is_available():
        return False
    # Vendor detection is name-based: AMD consumer (Radeon/RX) and
    # datacenter (MI/MI3xx) product tags in the reported device name.
    device_name = torch.cuda.get_device_name(0).lower()
    return any(tag in device_name for tag in ("radeon", "rx ", "mi ", "mi3"))
def configure_gpu(
device: str = "cuda",
*,
no_compile: bool = False,
no_amp: bool = False,
sdpa_math: bool = False,
) -> dict:
"""Auto-detect GPU and return optimal training settings.
Returns a dict with:
use_compile: bool — whether to torch.compile the forward pass
use_amp: bool — whether to use automatic mixed precision
sdpa_backend: SDPBackend | None — SDPA backend override (None = default)
CLI flags (no_compile, no_amp, sdpa_math) act as overrides. When not
set, the function picks the fastest settings for the detected GPU:
NVIDIA: compile + AMP + flash attention (default SDPA)
AMD: compile + AMP + MATH SDPA (avoids flash attn backward bug)
CPU: no compile, no AMP
"""
import torch
from torch.nn.attention import SDPBackend
is_cuda = device.startswith("cuda") and torch.cuda.is_available()
rocm = is_cuda and is_rocm()
# Defaults: compile and AMP on for CUDA
use_compile = is_cuda and not no_compile
use_amp = is_cuda and not no_amp
# SDPA backend: use MATH on ROCm (flash attn backward is broken with compile)
if sdpa_math:
sdpa_backend = SDPBackend.MATH
elif rocm and use_compile:
sdpa_backend = SDPBackend.MATH
else:
sdpa_backend = None
# Log what we're doing
if is_cuda:
gpu_name = torch.cuda.get_device_name(0)
platform = "ROCm" if rocm else "CUDA"
print(f"GPU: {gpu_name} ({platform})")
else:
print("GPU: none (CPU mode)")
if use_compile:
print(f" torch.compile: enabled (inductor)")
else:
print(f" torch.compile: disabled")
if use_amp:
print(f" AMP: enabled (fp16)")
else:
print(f" AMP: disabled")
if sdpa_backend is not None:
print(f" SDPA backend: {sdpa_backend.name}")
else:
print(f" SDPA backend: default")
return {
"use_compile": use_compile,
"use_amp": use_amp,
"sdpa_backend": sdpa_backend,
}
def apply_gpu_config(config: dict, model_module, forward_fn):
    """Apply GPU config: set SDPA backend and optionally compile forward_fn.

    Args:
        config: dict from configure_gpu()
        model_module: the pawn.model module (to set SDPA_BACKEND)
        forward_fn: the forward function to compile (e.g. model.forward_hidden)

    Returns:
        The (possibly compiled) forward function.
    """
    backend = config["sdpa_backend"]
    # Order matters: compiled code captures the backend at trace time, so the
    # override must be installed before torch.compile runs.
    if backend is not None:
        model_module.SDPA_BACKEND = backend

    if not config["use_compile"]:
        return forward_fn

    import torch
    return torch.compile(forward_fn)