# prompt-compiler-api / src/runtime/device_manager.py
# Source: Hugging Face file view (uploaded by JairoDanielMT via huggingface_hub,
# commit 4ef6c2b, verified; 1.6 kB).
import os
class DeviceManager:
    """Pick an ONNX Runtime execution provider and a matching batch size.

    All methods are static; the class serves only as a namespace.
    """

    # Accelerated providers in order of preference; CPU is the implicit fallback.
    _PROVIDER_PREFERENCE = ('CUDAExecutionProvider', 'DmlExecutionProvider')
    _CPU_PROVIDER = 'CPUExecutionProvider'

    # Batch size used when the CUDA provider is selected but GPU memory cannot
    # be measured (torch missing, torch without CUDA, or a CUDA query error).
    _CUDA_DEFAULT_BATCH_SIZE = 256
    # Batch size for every non-CUDA provider (CPU, DirectML).
    _FALLBACK_BATCH_SIZE = 128

    @staticmethod
    def get_optimal_device() -> str:
        """Return the preferred ONNX Runtime execution provider name.

        Preference order is CUDA -> DirectML -> CPU. If ``onnxruntime`` cannot
        be imported, ``'CPUExecutionProvider'`` is returned.
        """
        try:
            import onnxruntime as ort
        except ImportError:
            return DeviceManager._CPU_PROVIDER
        # NOTE(review): get_available_providers() reflects what the installed
        # onnxruntime build supports; it may list CUDA even when no usable GPU
        # or driver is present at runtime — confirm whether callers care.
        providers = ort.get_available_providers()
        for candidate in DeviceManager._PROVIDER_PREFERENCE:
            if candidate in providers:
                return candidate
        return DeviceManager._CPU_PROVIDER

    @staticmethod
    def _batch_size_for_vram(vram_gb: float) -> int:
        """Map total GPU memory in GiB to a batch size.

        >= 16 GiB -> 1024, >= 8 GiB -> 512, otherwise 256.
        """
        # NOTE(review): torch's total_memory is often slightly below a card's
        # marketed capacity, so a nominal "16 GB" card may land in the 512 tier
        # — confirm these thresholds are intended.
        if vram_gb >= 16:
            return 1024
        if vram_gb >= 8:
            return 512
        return 256

    @staticmethod
    def get_optimal_batch_size(device: "str | None" = None) -> int:
        """Choose a batch size from the execution provider and total GPU memory.

        Args:
            device: Execution provider name. Defaults to the result of
                ``get_optimal_device()``; pass it explicitly to avoid
                re-detecting the provider.

        Returns:
            For ``'CUDAExecutionProvider'``: 1024, 512 or 256 depending on the
            total memory of CUDA device 0 as reported by torch, or 256 when that
            cannot be measured. For any other provider: 128. CPU RAM is not
            inspected.
        """
        if device is None:
            device = DeviceManager.get_optimal_device()
        if device != 'CUDAExecutionProvider':
            return DeviceManager._FALLBACK_BATCH_SIZE
        try:
            import torch
        except ImportError:
            return DeviceManager._CUDA_DEFAULT_BATCH_SIZE
        try:
            if not torch.cuda.is_available():
                # Previously this fell through to the CPU/DML value; keep the
                # CUDA path consistent with the "torch missing" case instead.
                return DeviceManager._CUDA_DEFAULT_BATCH_SIZE
            total_bytes = torch.cuda.get_device_properties(0).total_memory
        except RuntimeError:
            # torch may raise if CUDA cannot be initialised; degrade, don't crash.
            return DeviceManager._CUDA_DEFAULT_BATCH_SIZE
        return DeviceManager._batch_size_for_vram(total_bytes / (1024 ** 3))