sage / train /hardware.py
sage002's picture
SAGE model repository : Updating some model checkpoints
3d2114e verified
"""Hardware detection and runtime configuration."""
from __future__ import annotations
from dataclasses import dataclass
import ctypes
import os
import platform
import torch
from train.distributed import get_training_strategy
try:
import psutil # type: ignore
except ImportError: # pragma: no cover - optional dependency
psutil = None
@dataclass
class HardwareConfig:
"""Detect hardware and derive runtime decisions."""
model_size_b: float
context_length: int
def __post_init__(self) -> None:
self.device, self.dtype = self._detect_device_dtype()
self.n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
self.vram_gb = self._get_vram()
self.ram_gb = self._get_ram_gb()
self.strategy = get_training_strategy(self.model_size_b)
self.micro_batch = self._pick_micro_batch()
self.grad_accum = self._pick_grad_accum()
self.use_amp = self.device != "cpu"
self.use_flash_attn = self.device == "cuda"
self.use_qlora = False
def _detect_device_dtype(self) -> tuple[str, torch.dtype]:
if torch.cuda.is_available():
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
return "cuda", dtype
if torch.backends.mps.is_available():
return "mps", torch.bfloat16
return "cpu", torch.float32
def _get_vram(self) -> float:
if not torch.cuda.is_available():
return 0.0
return torch.cuda.get_device_properties(0).total_memory / 1e9
def _get_ram_gb(self) -> float:
if psutil is not None:
return psutil.virtual_memory().total / 1e9
if platform.system() == "Windows":
kernel32 = ctypes.windll.kernel32
c_ulonglong = ctypes.c_ulonglong
mem_kb = c_ulonglong()
kernel32.GetPhysicallyInstalledSystemMemory(ctypes.byref(mem_kb))
return (mem_kb.value * 1024) / 1e9
if hasattr(os, "sysconf"):
pages = os.sysconf("SC_PHYS_PAGES")
page_size = os.sysconf("SC_PAGE_SIZE")
return (pages * page_size) / 1e9
return 0.0
def _pick_micro_batch(self) -> int:
if self.device == "cpu":
return 1
if self.vram_gb >= 80:
return 8
if self.vram_gb >= 40:
return 4
if self.vram_gb >= 24:
return 2
return 1
def _pick_grad_accum(self) -> int:
target_tokens = 2_000_000
tokens_per_micro = self.micro_batch * self.context_length * max(self.n_gpus, 1)
grad_accum = max(1, target_tokens // max(tokens_per_micro, 1))
if self.device == "cpu":
return min(8, grad_accum)
return grad_accum
def summary(self) -> dict[str, object]:
"""Return a JSON-safe hardware summary."""
effective_batch = self.micro_batch * self.grad_accum * self.context_length * max(self.n_gpus, 1)
return {
"device": self.device,
"dtype": str(self.dtype),
"n_gpus": self.n_gpus,
"vram_gb": round(self.vram_gb, 2),
"ram_gb": round(self.ram_gb, 2),
"strategy": self.strategy,
"micro_batch": self.micro_batch,
"grad_accum": self.grad_accum,
"effective_batch_tokens": effective_batch,
"use_flash_attn": self.use_flash_attn,
}