import os
import sys

# PyTorch decides whether unsupported MPS ops fall back to the CPU while its
# native library is being loaded (i.e. during ``import torch``), so this
# variable must be set *before* torch is imported. NOTE(review): if some other
# module imports torch earlier in the process, this assignment still comes
# too late — confirm import order at the entry point.
if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402 -- must come after the env-var setup above

# Make modules living in the current working directory importable.
now_dir = os.getcwd()
sys.path.append(now_dir)

from .logger.log import get_logger  # noqa: E402

logger = get_logger("gpu")
| |
|
| |
|
def select_device(min_memory=2047, experimental=False):
    """Pick the torch device to run on.

    Preference order: the CUDA GPU with the most free memory, then Apple MPS
    (only when ``experimental`` is true), otherwise the CPU.

    Args:
        min_memory: Minimum free memory, in MiB, that the best CUDA GPU must
            have. If it has less, the CPU is used instead.
        experimental: When true and MPS is available, use the ``mps`` device.
            Otherwise MPS is skipped in favour of the CPU.

    Returns:
        torch.device: ``cuda:<index>``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        selected_gpu = 0
        max_free_memory = -1
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            # NOTE: memory_reserved() only counts memory held by this
            # process's caching allocator, so memory used by other processes
            # on the same GPU is not subtracted here.
            free_memory = props.total_memory - torch.cuda.memory_reserved(i)
            if max_free_memory < free_memory:
                selected_gpu = i
                max_free_memory = free_memory
        free_memory_mb = max_free_memory / (1024 * 1024)
        if free_memory_mb < min_memory:
            # `logger` is used directly as a logger everywhere else in this
            # function; the previous `logger.get_logger()` call was
            # inconsistent with that and would fail on this path.
            logger.warning(
                f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
            )
            return torch.device("cpu")
        return torch.device(f"cuda:{selected_gpu}")

    if torch.backends.mps.is_available():
        # Per the original author, MPS is currently slower than the CPU while
        # needing more memory and core utilization, so it is only used when
        # explicitly requested.
        if experimental:
            logger.warning("experimental: found apple GPU, using MPS.")
            return torch.device("mps")
        logger.info("found Apple GPU, but use CPU.")
        return torch.device("cpu")

    logger.warning("no GPU found, use CPU instead")
    return torch.device("cpu")
| |
|