# File size: 1,207 Bytes (revision 3e72399)
# /projects/bfoj/AttrLLM/loader/utils.py
import os
import random
import platform
from dataclasses import dataclass
from typing import Dict, Any, Optional
# torch is an optional dependency. If importing it raises anything at all,
# the module-level name ``torch`` is bound to None instead, and the helpers
# in this module check for that before doing torch-specific work.
try:
    import torch
except Exception:  # broad on purpose: keeps this module importable even if torch's import fails
    torch = None
@dataclass
class DeviceInfo:
    """Typed record of runtime & device diagnostics.

    Its fields carry the same names as the keys of the dict built by
    ``device_info()``. Note that ``device_info()`` returns a plain dict,
    not an instance of this class.
    """

    python_version: str  # device_info() fills this from platform.python_version()
    platform: str  # device_info() fills this from platform.platform()
    cuda_available: bool
    cuda_device: Optional[str]  # None when no CUDA device name is available
    torch_version: Optional[str]  # None when torch could not be imported
def device_info(device_index: int = 0) -> Dict[str, Any]:
    """Return a dict with basic runtime & device diagnostics.

    Args:
        device_index: CUDA device whose name is reported under
            ``"cuda_device"``. Defaults to 0, the device previously
            hard-coded here.

    Returns:
        A dict with the keys ``"python_version"``, ``"platform"``,
        ``"torch_version"`` (None when torch is not importable),
        ``"cuda_available"`` and ``"cuda_device"``. ``"cuda_device"`` is None
        when CUDA is unavailable or when the device name cannot be queried,
        so ``"cuda_available"`` may be True while ``"cuda_device"`` is None.
    """
    cuda_avail = bool(torch is not None and torch.cuda.is_available())

    cuda_device: Optional[str] = None
    if cuda_avail:
        try:
            cuda_device = torch.cuda.get_device_name(device_index)
        except (RuntimeError, AssertionError):
            # A diagnostics helper should not crash. CUDA initialisation can
            # fail with RuntimeError, and torch rejects an out-of-range
            # device index with AssertionError ("Invalid device id").
            cuda_device = None

    return {
        "python_version": platform.python_version(),
        "platform": platform.platform(),
        "torch_version": getattr(torch, "__version__", None),
        "cuda_available": cuda_avail,
        "cuda_device": cuda_device,
    }
def set_seed(seed: int) -> None:
    """Seed Python & Torch RNGs for reproducibility (best-effort).

    Args:
        seed: Value passed to ``random.seed``, ``torch.manual_seed`` and (when
            CUDA is available) ``torch.cuda.manual_seed_all``, and exported
            as the ``PYTHONHASHSEED`` environment variable.

    Torch seeding is best-effort. If it raises (for example, for a seed
    outside the range torch accepts), a ``RuntimeWarning`` is emitted and the
    error is not propagated. Python's ``random`` module is still seeded and
    ``PYTHONHASHSEED`` is still set.

    Note: assigning ``PYTHONHASHSEED`` here does not change ``hash()`` of
    str/bytes in the running interpreter, because that is fixed at startup.
    It only affects child Python processes started afterwards that inherit
    this environment.
    """
    random.seed(seed)
    if torch is not None:
        try:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        except Exception as exc:  # best-effort: surface the failure instead of hiding it
            import warnings

            warnings.warn(
                f"set_seed: could not seed torch RNGs ({exc!r})",
                RuntimeWarning,
                stacklevel=2,
            )
    os.environ["PYTHONHASHSEED"] = str(seed)