| """GPU auto-detection and optimal training settings. |
| |
| Auto-detects NVIDIA vs AMD GPUs and configures torch.compile and SDPA |
| backend for best performance. ROCm's flash attention backward has stride |
| mismatches with torch.compile, so we fall back to the MATH SDPA backend |
| on AMD GPUs (still ~30% faster than no compile at all). |
| """ |
|
|
|
|
def is_rocm() -> bool:
    """Check if the current CUDA device is AMD/ROCm.

    Returns False when no CUDA-compatible device is available. Otherwise the
    primary signal is ``torch.version.hip``. ROCm builds of PyTorch set it to
    the HIP version string; CUDA builds leave it as None. Matching on the name
    of device 0 is kept as a fallback.
    """
    import torch

    if not torch.cuda.is_available():
        return False
    # ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API, and
    # torch.version.hip is the documented way to tell them apart. Name matching
    # alone misses parts like "AMD Instinct MI250X" (no "mi " and no "mi3").
    if getattr(torch.version, "hip", None):
        return True
    name = torch.cuda.get_device_name(0).lower()
    return any(
        marker in name for marker in ("radeon", "instinct", "rx ", "mi ", "mi3")
    )
|
|
|
|
| def configure_gpu( |
| device: str = "cuda", |
| *, |
| no_compile: bool = False, |
| no_amp: bool = False, |
| sdpa_math: bool = False, |
| ) -> dict: |
| """Auto-detect GPU and return optimal training settings. |
| |
| Returns a dict with: |
| use_compile: bool — whether to torch.compile the forward pass |
| use_amp: bool — whether to use automatic mixed precision |
| sdpa_backend: SDPBackend | None — SDPA backend override (None = default) |
| |
| CLI flags (no_compile, no_amp, sdpa_math) act as overrides. When not |
| set, the function picks the fastest settings for the detected GPU: |
| |
| NVIDIA: compile + AMP + flash attention (default SDPA) |
| AMD: compile + AMP + MATH SDPA (avoids flash attn backward bug) |
| CPU: no compile, no AMP |
| """ |
| import torch |
| from torch.nn.attention import SDPBackend |
|
|
| is_cuda = device.startswith("cuda") and torch.cuda.is_available() |
| rocm = is_cuda and is_rocm() |
|
|
| |
| use_compile = is_cuda and not no_compile |
| use_amp = is_cuda and not no_amp |
|
|
| |
| if sdpa_math: |
| sdpa_backend = SDPBackend.MATH |
| elif rocm and use_compile: |
| sdpa_backend = SDPBackend.MATH |
| else: |
| sdpa_backend = None |
|
|
| |
| if is_cuda: |
| gpu_name = torch.cuda.get_device_name(0) |
| platform = "ROCm" if rocm else "CUDA" |
| print(f"GPU: {gpu_name} ({platform})") |
| else: |
| print("GPU: none (CPU mode)") |
|
|
| if use_compile: |
| print(f" torch.compile: enabled (inductor)") |
| else: |
| print(f" torch.compile: disabled") |
|
|
| if use_amp: |
| print(f" AMP: enabled (fp16)") |
| else: |
| print(f" AMP: disabled") |
|
|
| if sdpa_backend is not None: |
| print(f" SDPA backend: {sdpa_backend.name}") |
| else: |
| print(f" SDPA backend: default") |
|
|
| return { |
| "use_compile": use_compile, |
| "use_amp": use_amp, |
| "sdpa_backend": sdpa_backend, |
| } |
|
|
|
|
def apply_gpu_config(config: dict, model_module, forward_fn):
    """Push a configure_gpu() result into the model module and forward pass.

    Args:
        config: dict from configure_gpu(); reads "sdpa_backend" then
            "use_compile".
        model_module: the pawn.model module. Its SDPA_BACKEND attribute is
            overwritten only when config carries a non-None backend.
        forward_fn: the forward function to compile (e.g. model.forward_hidden)

    Returns:
        torch.compile(forward_fn) when compilation is requested, otherwise
        forward_fn unchanged.
    """
    backend_override = config["sdpa_backend"]
    if backend_override is not None:
        # None means "keep whatever backend the module already uses".
        model_module.SDPA_BACKEND = backend_override

    if not config["use_compile"]:
        return forward_fn

    import torch
    return torch.compile(forward_fn)
|
|