Spaces:
Sleeping
Sleeping
| import os | |
class DeviceManager:
    """Pick an ONNX Runtime execution provider and a matching batch size.

    Both helpers are stateless static methods, so they can be called on the
    class (``DeviceManager.get_optimal_device()``) or on an instance.
    """

    @staticmethod
    def get_optimal_device() -> str:
        """Return the preferred ONNX Runtime execution provider name.

        Preference order is CUDA -> DirectML -> CPU, based on what
        ``onnxruntime.get_available_providers()`` reports.

        Returns:
            ``'CUDAExecutionProvider'``, ``'DmlExecutionProvider'`` or
            ``'CPUExecutionProvider'``. Falls back to the CPU provider when
            ``onnxruntime`` cannot be imported.
        """
        try:
            import onnxruntime as ort
        except ImportError:
            return 'CPUExecutionProvider'

        # NOTE(review): get_available_providers() reflects the providers the
        # installed onnxruntime build supports; it may list CUDA even when no
        # usable GPU/driver is present at runtime — confirm if that matters.
        providers = ort.get_available_providers()
        if 'CUDAExecutionProvider' in providers:
            return 'CUDAExecutionProvider'
        if 'DmlExecutionProvider' in providers:
            return 'DmlExecutionProvider'
        return 'CPUExecutionProvider'

    @staticmethod
    def get_optimal_batch_size() -> int:
        """Return a batch-size heuristic for the provider from get_optimal_device().

        Only the CUDA path is sized from hardware: it reads the *total* memory
        of CUDA device 0 through PyTorch and returns 1024 (>= 16 GiB),
        512 (>= 8 GiB) or 256 (smaller). If PyTorch is not installed, the CUDA
        path returns 256. Every other case — DirectML, CPU, or a CUDA provider
        while ``torch.cuda.is_available()`` is False — returns a fixed 128.
        CPU RAM is not inspected.
        """
        if DeviceManager.get_optimal_device() != 'CUDAExecutionProvider':
            # DirectML / CPU fallback.
            return 128

        try:
            import torch
        except ImportError:
            # CUDA provider present but no PyTorch to query VRAM with: use the
            # same size as the smallest-VRAM tier.
            return 256

        if not torch.cuda.is_available():
            # onnxruntime advertises CUDA but PyTorch sees no device.
            return 128

        # Total (not currently free) memory of the first CUDA device, in GiB.
        vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
        if vram_gb >= 16:
            return 1024
        if vram_gb >= 8:
            return 512
        return 256