import os

import torch

from .device_id import DeviceId

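# NOTE: CUDA reads CUDA_VISIBLE_DEVICES only once, when it is first initialized in
# this process (`torch.cuda.is_available()` may already trigger this). Call
# `set()` before any CUDA work, or the GPU selection will have no effect.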

class DeviceException(Exception):
    pass

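class _Device:
    ''' Tracks the selected device (CPU or a GPU index) and the accelerator backend. '''

    def __init__(self):
        self._current_device = DeviceId.CPU
        # The backend is detected lazily: probing CUDA here would initialize it
        # before `set()` has had a chance to write CUDA_VISIBLE_DEVICES.
        self._backend = None

    def _init_device(self):
        ''' Detects the accelerator backend: 'xpu', 'cuda', or 'cpu'. '''
        # Check for an Intel GPU via Intel Extension for PyTorch
        try:
            import intel_extension_for_pytorch  # noqa: F401  (imported for its side effects)
            if torch.xpu.is_available():
                self._backend = 'xpu'
                return
        except (ImportError, AttributeError):
            pass

        # Check for an NVIDIA GPU via CUDA
        if torch.cuda.is_available():
            self._backend = 'cuda'
            return

        self._backend = 'cpu'

    def _get_backend(self):
        ''' Returns the detected backend, running detection on first use. '''
        if self._backend is None:
            self._init_device()
        return self._backend

    def is_gpu(self):
        ''' Returns `True` if the current device is a GPU (CUDA or XPU), `False` otherwise. '''
        return self.current() is not DeviceId.CPU

    def current(self):
        ''' Returns the currently selected `DeviceId`. '''
        return self._current_device

    def set(self, device: DeviceId):
        ''' Selects `device` and returns it. '''
        if device == DeviceId.CPU:
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
        else:
            # Expose only the requested GPU to CUDA before the backend probe below
            # initializes it; if CUDA was not initialized earlier, that GPU then
            # appears to PyTorch as 'cuda:0'.
            os.environ['CUDA_VISIBLE_DEVICES'] = str(device.value)
            if self._get_backend() == 'cuda':
                torch.backends.cudnn.benchmark = True
            # For XPU no device mask is set; get_torch_device() uses the index directly.

        self._current_device = device
        return device

    def get_torch_device(self):
        ''' Returns the `torch.device` corresponding to the current selection. '''
        if self._current_device == DeviceId.CPU:
            return torch.device('cpu')

        backend = self._get_backend()
        if backend == 'cuda':
            # CUDA_VISIBLE_DEVICES exposes only the selected GPU, at index 0.
            return torch.device('cuda')
        if backend == 'xpu':
            return torch.device('xpu', int(self._current_device.value))

        # No accelerator was detected: fall back to the CPU.
        return torch.device('cpu')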