File size: 2,720 Bytes
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
import logging


def get_device(number: int, logger: logging.Logger = None) -> torch.device:
    """
    Configures PyTorch to use a specified GPU by its index number,
    or falls back to CPU if CUDA is not available.

    Args:
        number (int): The index number of the GPU to use.
        logger (logging.Logger, optional): Logger for logging GPU info.

    Returns:
        torch.device: The selected torch device (GPU or CPU).

    Raises:
        ValueError: If CUDA is available but ``number`` is outside the
            range of valid GPU indices.
    """
    # Fallback to CPU if CUDA is not available
    if not torch.cuda.is_available():
        if logger:
            logger.warning("CUDA is not available. Falling back to CPU.")
        return torch.device('cpu')

    # Check if the specified GPU number is valid
    if number >= torch.cuda.device_count() or number < 0:
        raise ValueError(
            f"GPU number {number} is not valid. Available GPU indices range from 0 to {torch.cuda.device_count() - 1}.")

    # Select the target device BEFORE cleaning up. The cache/stat reset
    # functions act on the *current* device when no device argument is
    # given, so the original order cleaned whatever GPU happened to be
    # the default instead of GPU `number`. Passing `number` explicitly
    # to the reset calls makes the target unambiguous either way.
    torch.cuda.set_device(number)

    # Clean up cached memory and reset memory statistics for this GPU.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(number)
    torch.cuda.reset_accumulated_memory_stats(number)

    if logger:
        # Fetch the device name once and reuse it in both log lines.
        device_name = torch.cuda.get_device_name(number)
        logger.info(f"PyTorch is now configured to use GPU {number}: {device_name}")

        # All figures below are converted from bytes to MB (1024 ** 2).
        total_mem = torch.cuda.get_device_properties(number).total_memory / 1024 ** 2
        mem_allocated = torch.cuda.memory_allocated(number) / 1024 ** 2
        mem_reserved = torch.cuda.memory_reserved(number) / 1024 ** 2
        max_allocated = torch.cuda.max_memory_allocated(number) / 1024 ** 2
        max_reserved = torch.cuda.max_memory_reserved(number) / 1024 ** 2

        logger.info(f"[GPU {number} - {device_name}] Memory Stats:")
        logger.info(f"  Total Memory      : {total_mem:.2f} MB")
        logger.info(f"  Currently Allocated : {mem_allocated:.2f} MB")
        logger.info(f"  Currently Reserved  : {mem_reserved:.2f} MB")
        logger.info(f"  Max Allocated       : {max_allocated:.2f} MB")
        logger.info(f"  Max Reserved        : {max_reserved:.2f} MB")

    return torch.device(f'cuda:{number}')
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #