| |
| import os |
| import random |
| import platform |
| from dataclasses import dataclass |
| from typing import Dict, Any, Optional |
|
|
| try: |
| import torch |
| except Exception: |
| torch = None |
|
|
@dataclass()
class DeviceInfo:
    """Typed record of basic runtime and device diagnostics.

    The field names are the same as the keys of the dict produced by
    ``device_info()`` in this module.
    """

    # Interpreter version string.
    python_version: str
    # Host platform description string.
    platform: str
    # Whether torch reported a usable CUDA runtime.
    cuda_available: bool
    # Name of the CUDA device, or None when there is none to report.
    cuda_device: Optional[str]
    # torch version string, or None when torch is not importable.
    torch_version: Optional[str]
|
|
def device_info() -> Dict[str, Any]:
    """Return a dict with basic runtime & device diagnostics.

    Keys (identical to the field names of ``DeviceInfo``):
        python_version: ``platform.python_version()``.
        platform: ``platform.platform()``.
        torch_version: ``torch.__version__``, or ``None`` if torch is absent.
        cuda_available: ``torch.cuda.is_available()``; ``False`` without torch
            or when the availability probe itself fails.
        cuda_device: name of CUDA device 0, or ``None`` when CUDA is not
            available or the device name could not be queried.

    This is a diagnostics helper, so a failing CUDA query degrades to
    ``False``/``None`` instead of raising.
    """
    cuda_avail = False
    cuda_device: Optional[str] = None
    if torch is not None:
        try:
            cuda_avail = bool(torch.cuda.is_available())
            if cuda_avail:
                cuda_device = torch.cuda.get_device_name(0)
        except (RuntimeError, AssertionError):
            # Driver/initialisation errors surface as RuntimeError; a torch
            # build without CUDA support raises AssertionError on lazy init.
            # Keep whatever availability was already determined and report
            # no device name.
            cuda_device = None
    # NOTE(review): the key set mirrors DeviceInfo's fields; callers may
    # build DeviceInfo(**device_info()) — keep both in sync when editing.
    return {
        "python_version": platform.python_version(),
        "platform": platform.platform(),
        "torch_version": getattr(torch, "__version__", None),
        "cuda_available": cuda_avail,
        "cuda_device": cuda_device,
    }
|
|
def set_seed(seed: int) -> None:
    """Seed Python & Torch RNGs for reproducibility (best-effort).

    Seeds the stdlib ``random`` module and, when torch was importable, the
    torch generator via ``torch.manual_seed`` plus every CUDA generator when
    CUDA is available. Seeding torch is best-effort: if it raises (for example
    because torch rejects the seed value), a ``RuntimeWarning`` is emitted and
    the call continues instead of propagating the error.

    ``PYTHONHASHSEED`` is exported as well. Hash randomisation is fixed when
    the interpreter starts, so this only influences child processes started
    afterwards, not ``hash()`` in the current process.

    Args:
        seed: Seed value applied to every RNG.
    """
    import warnings

    random.seed(seed)
    if torch is not None:
        try:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        except Exception as exc:  # best-effort: seeding must never abort the caller
            # Previously swallowed silently, which hid non-reproducible runs.
            warnings.warn(
                f"set_seed: could not seed torch RNGs ({exc!r}); "
                "torch results may not be reproducible.",
                RuntimeWarning,
                stacklevel=2,
            )
    os.environ["PYTHONHASHSEED"] = str(seed)