# DeOldify — deoldify/_device.py
# (Initial commit for Hugging Face sync, clean history — e9f9fd3)
import os
import torch
from enum import Enum
from .device_id import DeviceId
import logging
#NOTE: This module must be imported before any other torch/CUDA usage in order to work
# properly: the CUDA_VISIBLE_DEVICES changes made in `_Device.set` only take effect
# if they happen before CUDA is initialized.
class DeviceException(Exception):
    """Raised for device-selection errors in DeOldify."""
class _Device:
    """Tracks the active compute device (CPU, CUDA or Intel XPU) for DeOldify.

    The available backend is probed once at construction time; `set` selects
    a specific `DeviceId`, and `get_torch_device` converts the current
    selection into a `torch.device`.
    """

    def __init__(self):
        # Default to CPU until backend detection runs.
        self._current_device = DeviceId.CPU
        self._backend = 'cpu'
        self._init_device()

    def _init_device(self):
        """Detect the best available backend, preferring XPU > CUDA > CPU."""
        # Check for Intel Extension for PyTorch; importing it registers the
        # XPU backend with torch as a side effect.
        try:
            import intel_extension_for_pytorch as ipex  # noqa: F401
            # hasattr guard: some torch builds lack the `xpu` attribute even
            # when the ipex import succeeds, which would raise AttributeError.
            if hasattr(torch, 'xpu') and torch.xpu.is_available():
                self._backend = 'xpu'
                return
        except (ImportError, AttributeError):
            pass
        # Check for CUDA
        if torch.cuda.is_available():
            self._backend = 'cuda'
            return
        self._backend = 'cpu'

    def is_gpu(self):
        ''' Returns `True` if the current device is GPU (CUDA or XPU), `False` otherwise. '''
        return self.current() is not DeviceId.CPU

    def current(self):
        """Return the currently selected `DeviceId`."""
        return self._current_device

    def set(self, device:DeviceId):
        """Select *device* as the active device and return it.

        Selecting `DeviceId.CPU` hides all CUDA devices via
        `CUDA_VISIBLE_DEVICES`; selecting a GPU id exposes only that device
        (CUDA backend) or simply records the choice (XPU backend).
        """
        if device == DeviceId.CPU:
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
            self._current_device = DeviceId.CPU
        else:
            # Handle GPU selection
            if self._backend == 'cuda':
                # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if set
                # before CUDA is initialized — hence the module-level NOTE.
                os.environ['CUDA_VISIBLE_DEVICES'] = str(device.value)
                torch.backends.cudnn.benchmark = True
            elif self._backend == 'xpu':
                # For XPU, we might need different env vars or just rely on
                # the index; currently just recording the device id.
                pass
            self._current_device = device
        return device

    def get_torch_device(self):
        """Return a `torch.device` matching the current selection."""
        if self._current_device == DeviceId.CPU:
            return torch.device('cpu')
        if self._backend == 'cuda':
            return torch.device('cuda')
        elif self._backend == 'xpu':
            return torch.device('xpu')
        # Fallback: unknown backend combinations resolve to CPU.
        return torch.device('cpu')