# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
import logging
def get_device(number: int, logger: logging.Logger = None):
    """
    Configures PyTorch to use a specified GPU by its index number,
    or falls back to CPU if CUDA is not available.
    Args:
        number (int): The index number of the GPU to use.
        logger (logging.Logger, optional): Logger for logging GPU info.
    Returns:
        torch.device: The selected torch device (GPU or CPU).
    Raises:
        ValueError: If CUDA is available but `number` is not a valid
            GPU index for this machine.
    """
    # Fallback to CPU if CUDA is not available
    if not torch.cuda.is_available():
        if logger:
            logger.warning("CUDA is not available. Falling back to CPU.")
        return torch.device('cpu')
    # Check if the specified GPU number is valid
    if number >= torch.cuda.device_count() or number < 0:
        raise ValueError(
            f"GPU number {number} is not valid. Available GPU indices range from 0 to {torch.cuda.device_count() - 1}.")
    # Select the device FIRST so that device-less CUDA calls below act on
    # GPU `number` rather than on whatever device was previously current.
    torch.cuda.set_device(number)
    # Clean up memory and stats on the selected device. The reset_* calls
    # take an explicit device index to make the target unambiguous.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(number)
    torch.cuda.reset_accumulated_memory_stats(number)
    # Log device identity and memory statistics (MB) if a logger was given
    if logger:
        device_name = torch.cuda.get_device_name(number)
        logger.info(f"PyTorch is now configured to use GPU {number}: {device_name}")
        total_mem = torch.cuda.get_device_properties(number).total_memory / 1024 ** 2
        mem_allocated = torch.cuda.memory_allocated(number) / 1024 ** 2
        mem_reserved = torch.cuda.memory_reserved(number) / 1024 ** 2
        max_allocated = torch.cuda.max_memory_allocated(number) / 1024 ** 2
        max_reserved = torch.cuda.max_memory_reserved(number) / 1024 ** 2
        logger.info(f"[GPU {number} - {device_name}] Memory Stats:")
        logger.info(f"  Total Memory        : {total_mem:.2f} MB")
        logger.info(f"  Currently Allocated : {mem_allocated:.2f} MB")
        logger.info(f"  Currently Reserved  : {mem_reserved:.2f} MB")
        logger.info(f"  Max Allocated       : {max_allocated:.2f} MB")
        logger.info(f"  Max Reserved        : {max_reserved:.2f} MB")
    return torch.device(f'cuda:{number}')
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|