File size: 319 Bytes
f880dff
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import torch


def get_auto_device():
    """
    Choose the best available compute device name for this process.

    ``torch_npu`` is treated as an optional extension. If it cannot be
    imported, NPU support is considered absent.

    Preference order: CUDA first, then NPU (through ``torch_npu``), then CPU.

    Returns:
        str: ``'cuda'``, ``'npu'`` or ``'cpu'``.
    """
    # The import is attempted before the CUDA check, so it always happens
    # (and any import-time side effects occur) even on CUDA machines.
    try:
        import torch_npu
        npu_ready = torch_npu.npu.is_available()
    except ImportError:
        # Extension not installed: no NPU backend to offer.
        npu_ready = False

    if torch.cuda.is_available():
        return 'cuda'
    if npu_ready:
        return 'npu'
    return 'cpu'