| import importlib |
| import torch |
| from typing import Any |
|
|
|
|
def is_torch_npu_available() -> bool:
    """Return True when the ``torch_npu`` package can be located.

    Only the import machinery is queried via ``find_spec``; this function
    does not itself import ``torch_npu``.
    """
    # ``import importlib`` alone does not guarantee that the ``util``
    # submodule is loaded, so import it explicitly before using it.
    import importlib.util

    return importlib.util.find_spec("torch_npu") is not None
|
|
|
|
# Backend availability is probed once, when this module is imported.
IS_CUDA_AVAILABLE = torch.cuda.is_available()
# torch.npu is only consulted when the torch_npu package can be located.
IS_NPU_AVAILABLE = torch.npu.is_available() if is_torch_npu_available() else False

if IS_NPU_AVAILABLE:
    # The imported name is not referenced in the visible code.
    import torch_npu  # noqa: F401

    torch.npu.config.allow_internal_format = False
|
|
|
|
def get_device_type() -> str:
    """Return the device type of this machine: "cuda", "npu" or "cpu".

    CUDA takes precedence over NPU; "cpu" is the fallback when neither
    accelerator was detected at import time.
    """
    if IS_CUDA_AVAILABLE:
        return "cuda"
    if IS_NPU_AVAILABLE:
        return "npu"
    return "cpu"
|
|
|
|
def get_torch_device() -> Any:
    """Return the torch namespace for the current device type.

    This is the attribute of ``torch`` named after ``get_device_type()``,
    e.g. ``torch.cuda`` or ``torch.npu``. When ``torch`` has no such
    attribute a notice is printed and ``torch.cuda`` is returned instead.
    """
    device_name = get_device_type()

    if hasattr(torch, device_name):
        return getattr(torch, device_name)

    # Fallback path: the namespace is missing from this torch build.
    print(f"Device namespace '{device_name}' not found in torch, try to load 'torch.cuda'.")
    return torch.cuda
|
|
|
|
def get_device_id() -> int:
    """Return ``current_device()`` of the active torch device namespace.

    NOTE(review): the ``int`` annotation presumably holds for accelerator
    namespaces; ``torch.cpu.current_device()`` is documented to return the
    string "cpu" -- confirm what CPU-only callers expect.
    """
    device_module = get_torch_device()
    return device_module.current_device()
|
|
|
|
def get_device_name() -> str:
    """Return the current device as a string usable with ``torch.device``.

    Accelerators yield ``"<type>:<index>"`` (for example ``"cuda:0"``); a
    CPU-only host yields plain ``"cpu"``.
    """
    device_type = get_device_type()
    if device_type == "cpu":
        # torch.cpu.current_device() reports the string "cpu" rather than an
        # index, so formatting it would produce the invalid name "cpu:cpu".
        return device_type
    return f"{device_type}:{get_device_id()}"
|
|
|
|
def synchronize() -> None:
    """Call ``synchronize()`` on the active torch device namespace."""
    device_module = get_torch_device()
    device_module.synchronize()
|
|
|
|
def empty_cache() -> None:
    """Call ``empty_cache()`` on the active torch device namespace, if present.

    Namespaces that do not expose ``empty_cache`` (``torch.cpu`` does not
    provide one) are skipped, so callers can invoke this unconditionally
    instead of hitting an ``AttributeError`` on CPU-only hosts.
    """
    release = getattr(get_torch_device(), "empty_cache", None)
    if release is not None:
        release()
|
|
|
|
def get_nccl_backend() -> str:
    """Return the distributed communication backend for this machine.

    Returns "nccl" when CUDA is available, otherwise "hccl" when NPU is
    available.

    Raises:
        RuntimeError: if neither CUDA nor NPU is available.
    """
    if IS_CUDA_AVAILABLE:
        return "nccl"
    if not IS_NPU_AVAILABLE:
        raise RuntimeError(f"No available distributed communication backend found on device type {get_device_type()}.")
    return "hccl"
|
|
|
|
def enable_high_precision_for_bf16():
    """Turn off the reduced-precision matmul flags on available accelerators.

    Sets ``allow_tf32`` and ``allow_bf16_reduced_precision_reduction`` to
    ``False`` on ``torch.backends.cuda.matmul`` when CUDA is available and on
    ``torch.npu.matmul`` when NPU is available.
    """
    reduced_precision_flags = ("allow_tf32", "allow_bf16_reduced_precision_reduction")

    if IS_CUDA_AVAILABLE:
        for flag in reduced_precision_flags:
            setattr(torch.backends.cuda.matmul, flag, False)

    if IS_NPU_AVAILABLE:
        for flag in reduced_precision_flags:
            setattr(torch.npu.matmul, flag, False)
|
|
|
|
def parse_device_type(device):
    """Map a device string or ``torch.device`` to its device-type name.

    Strings starting with "cuda" or "npu" map to that prefix and every other
    string maps to "cpu". A ``torch.device`` maps to its ``type`` attribute.
    Any other input falls through and yields ``None``.
    """
    if isinstance(device, torch.device):
        return device.type
    if not isinstance(device, str):
        # Neither a string nor a torch.device: no device type can be derived.
        return None

    for prefix in ("cuda", "npu"):
        if device.startswith(prefix):
            return prefix
    return "cpu"
|
|
|
|
def parse_nccl_backend(device_type):
    """Return the distributed communication backend for ``device_type``.

    "cuda" maps to "nccl" and "npu" maps to "hccl".

    Raises:
        RuntimeError: for any other ``device_type``.
    """
    backend_by_device = (("cuda", "nccl"), ("npu", "hccl"))
    for known_type, backend in backend_by_device:
        if device_type == known_type:
            return backend
    raise RuntimeError(f"No available distributed communication backend found on device type {device_type}.")
|
|
|
|
def get_available_device_type():
    """Alias of ``get_device_type``: return the detected device type string."""
    device_type = get_device_type()
    return device_type
|
|