| """Hardware detection and runtime configuration.""" |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| import ctypes |
| import os |
| import platform |
|
|
| import torch |
|
|
| from train.distributed import get_training_strategy |
|
|
| try: |
| import psutil |
| except ImportError: |
| psutil = None |
|
|
|
|
| @dataclass |
| class HardwareConfig: |
| """Detect hardware and derive runtime decisions.""" |
|
|
| model_size_b: float |
| context_length: int |
|
|
| def __post_init__(self) -> None: |
| self.device, self.dtype = self._detect_device_dtype() |
| self.n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 |
| self.vram_gb = self._get_vram() |
| self.ram_gb = self._get_ram_gb() |
| self.strategy = get_training_strategy(self.model_size_b) |
| self.micro_batch = self._pick_micro_batch() |
| self.grad_accum = self._pick_grad_accum() |
| self.use_amp = self.device != "cpu" |
| self.use_flash_attn = self.device == "cuda" |
| self.use_qlora = False |
|
|
| def _detect_device_dtype(self) -> tuple[str, torch.dtype]: |
| if torch.cuda.is_available(): |
| dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 |
| return "cuda", dtype |
| if torch.backends.mps.is_available(): |
| return "mps", torch.bfloat16 |
| return "cpu", torch.float32 |
|
|
| def _get_vram(self) -> float: |
| if not torch.cuda.is_available(): |
| return 0.0 |
| return torch.cuda.get_device_properties(0).total_memory / 1e9 |
|
|
| def _get_ram_gb(self) -> float: |
| if psutil is not None: |
| return psutil.virtual_memory().total / 1e9 |
| if platform.system() == "Windows": |
| kernel32 = ctypes.windll.kernel32 |
| c_ulonglong = ctypes.c_ulonglong |
| mem_kb = c_ulonglong() |
| kernel32.GetPhysicallyInstalledSystemMemory(ctypes.byref(mem_kb)) |
| return (mem_kb.value * 1024) / 1e9 |
| if hasattr(os, "sysconf"): |
| pages = os.sysconf("SC_PHYS_PAGES") |
| page_size = os.sysconf("SC_PAGE_SIZE") |
| return (pages * page_size) / 1e9 |
| return 0.0 |
|
|
| def _pick_micro_batch(self) -> int: |
| if self.device == "cpu": |
| return 1 |
| if self.vram_gb >= 80: |
| return 8 |
| if self.vram_gb >= 40: |
| return 4 |
| if self.vram_gb >= 24: |
| return 2 |
| return 1 |
|
|
| def _pick_grad_accum(self) -> int: |
| target_tokens = 2_000_000 |
| tokens_per_micro = self.micro_batch * self.context_length * max(self.n_gpus, 1) |
| grad_accum = max(1, target_tokens // max(tokens_per_micro, 1)) |
| if self.device == "cpu": |
| return min(8, grad_accum) |
| return grad_accum |
|
|
| def summary(self) -> dict[str, object]: |
| """Return a JSON-safe hardware summary.""" |
| effective_batch = self.micro_batch * self.grad_accum * self.context_length * max(self.n_gpus, 1) |
| return { |
| "device": self.device, |
| "dtype": str(self.dtype), |
| "n_gpus": self.n_gpus, |
| "vram_gb": round(self.vram_gb, 2), |
| "ram_gb": round(self.ram_gb, 2), |
| "strategy": self.strategy, |
| "micro_batch": self.micro_batch, |
| "grad_accum": self.grad_accum, |
| "effective_batch_tokens": effective_batch, |
| "use_flash_attn": self.use_flash_attn, |
| } |
|
|