import os


class DeviceManager:
    """Pick an ONNX Runtime execution provider and a matching batch size."""

    # Execution provider names, listed in order of preference.
    _CUDA_PROVIDER = 'CUDAExecutionProvider'
    _DML_PROVIDER = 'DmlExecutionProvider'
    _CPU_PROVIDER = 'CPUExecutionProvider'

    # Batch size used for every non-CUDA provider (CPU or DirectML).
    _FALLBACK_BATCH_SIZE = 128
    # Smallest CUDA batch size; also used when VRAM cannot be queried.
    _MIN_CUDA_BATCH_SIZE = 256
    # (minimum VRAM in GiB, batch size) pairs, largest threshold first.
    _VRAM_BATCH_TIERS = ((16, 1024), (8, 512))

    @staticmethod
    def get_optimal_device() -> str:
        """Return the preferred ONNX Runtime execution provider.

        The choice is made from ``onnxruntime.get_available_providers()``
        with the preference order CUDA -> DirectML -> CPU.

        Returns:
            ``'CUDAExecutionProvider'``, ``'DmlExecutionProvider'`` or
            ``'CPUExecutionProvider'``. The CPU provider is also returned
            when onnxruntime cannot be imported.
        """
        try:
            import onnxruntime as ort
        except (ImportError, OSError):
            # OSError: a broken native install may fail while loading its
            # shared libraries instead of raising ImportError.
            return DeviceManager._CPU_PROVIDER

        providers = ort.get_available_providers()
        for preferred in (DeviceManager._CUDA_PROVIDER,
                          DeviceManager._DML_PROVIDER):
            if preferred in providers:
                return preferred
        return DeviceManager._CPU_PROVIDER

    @staticmethod
    def get_optimal_batch_size() -> int:
        """Return a batch size suited to the selected execution provider.

        For the CUDA provider the size is derived from the total memory of
        CUDA device 0 as reported by torch: 1024 for >= 16 GiB, 512 for
        >= 8 GiB, otherwise 256. If torch cannot be imported, or the device
        query fails, 256 is returned. System RAM is not inspected.

        Returns:
            128 for the CPU and DirectML providers, and also when torch
            reports that no CUDA device is available; otherwise one of
            256, 512 or 1024 as described above.
        """
        if DeviceManager.get_optimal_device() != DeviceManager._CUDA_PROVIDER:
            return DeviceManager._FALLBACK_BATCH_SIZE

        try:
            import torch
        except (ImportError, OSError):
            # Same reasoning as for onnxruntime: missing or broken install.
            return DeviceManager._MIN_CUDA_BATCH_SIZE

        try:
            if not torch.cuda.is_available():
                # torch sees no usable CUDA device: use the generic size.
                return DeviceManager._FALLBACK_BATCH_SIZE
            total_bytes = torch.cuda.get_device_properties(0).total_memory
        except RuntimeError:
            # Driver/initialisation failure while querying the device;
            # stay conservative rather than crash a best-effort probe.
            return DeviceManager._MIN_CUDA_BATCH_SIZE

        vram_gb = total_bytes / (1024 ** 3)
        for min_vram_gb, batch_size in DeviceManager._VRAM_BATCH_TIERS:
            if vram_gb >= min_vram_gb:
                return batch_size
        return DeviceManager._MIN_CUDA_BATCH_SIZE