# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import logging

import torch


def _log_gpu_memory_stats(number: int, logger: logging.Logger) -> None:
    """Log the name and memory statistics (in MB) of the CUDA device ``number``."""
    mb = 1024 ** 2
    device_name = torch.cuda.get_device_name(number)
    total_mem = torch.cuda.get_device_properties(number).total_memory / mb
    mem_allocated = torch.cuda.memory_allocated(number) / mb
    mem_reserved = torch.cuda.memory_reserved(number) / mb
    max_allocated = torch.cuda.max_memory_allocated(number) / mb
    max_reserved = torch.cuda.max_memory_reserved(number) / mb

    logger.info(f"PyTorch is now configured to use GPU {number}: {device_name}")
    logger.info(f"[GPU {number} - {device_name}] Memory Stats:")
    logger.info(f" Total Memory : {total_mem:.2f} MB")
    logger.info(f" Currently Allocated : {mem_allocated:.2f} MB")
    logger.info(f" Currently Reserved : {mem_reserved:.2f} MB")
    logger.info(f" Max Allocated : {max_allocated:.2f} MB")
    logger.info(f" Max Reserved : {max_reserved:.2f} MB")


def get_device(number: int, logger: "logging.Logger | None" = None) -> torch.device:
    """
    Configures PyTorch to use a specified GPU by its index number, or falls back to CPU if CUDA is not available.

    When CUDA is available, the requested GPU becomes the current CUDA device.
    Its cached memory is released and its peak/accumulated memory statistics
    are reset.

    Args:
        number (int): The index number of the GPU to use. When CUDA is available it
            must satisfy ``0 <= number < torch.cuda.device_count()``; it is not
            checked when falling back to CPU.
        logger (logging.Logger, optional): Logger for GPU info. When ``None``,
            nothing is logged.

    Returns:
        torch.device: ``cuda:<number>`` when CUDA is available, otherwise ``cpu``.

    Raises:
        ValueError: If CUDA is available and ``number`` is not a valid GPU index.
    """
    # Fallback to CPU if CUDA is not available
    if not torch.cuda.is_available():
        if logger is not None:
            logger.warning("CUDA is not available. Falling back to CPU.")
        return torch.device('cpu')

    # Check if the specified GPU number is valid
    device_count = torch.cuda.device_count()
    if number < 0 or number >= device_count:
        raise ValueError(
            f"GPU number {number} is not valid. Available GPU indices range from 0 to {device_count - 1}.")

    # Select the device *before* cleaning up: empty_cache() acts on the current
    # device, so running it first would target the previously current GPU
    # (possibly creating a CUDA context there) instead of the requested one.
    torch.cuda.set_device(number)

    # Clean up memory and stats on the selected GPU
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(number)
    torch.cuda.reset_accumulated_memory_stats(number)

    # Every logging call is guarded so that logger=None never raises.
    if logger is not None:
        _log_gpu_memory_stats(number, logger)

    return torch.device(f'cuda:{number}')

# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #