File size: 3,573 Bytes
9823a7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
import subprocess
from typing import List
from thefuzz import process
from tvm.target import Target
from tvm.target.tag import list_tags

import logging

logger = logging.getLogger(__name__)

# Message logged when no suitable TVM target can be determined automatically.
TARGET_MISSING_ERROR = (
    "TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=<target>`, "
    "where <target> is one of the available targets, which can be found in the output of "
    "`tools/get_available_targets.py`."
)

# NVIDIA ships non-public OEM GPU models whose driver-reported names do not map to any TVM target tag.
# Remap those OEM model names to the closest public model name before matching against TVM tags.
NVIDIA_GPU_REMAP = {
    "NVIDIA PG506-230": "NVIDIA A100",
    "NVIDIA PG506-232": "NVIDIA A100",
}

def get_gpu_model_from_nvidia_smi(gpu_id: int = 0):
    """
    Query 'nvidia-smi' for the name of the NVIDIA GPU at index ``gpu_id``.

    Args:
        gpu_id (int): Index into the GPUs listed by 'nvidia-smi' (default 0, the first GPU).

    Returns:
        str: The name of the GPU, or None if 'nvidia-smi' is missing, fails, or reports no GPUs.

    Raises:
        EnvironmentError: If more than one GPU is reported and ``CUDA_DEVICE_ORDER`` is not
            set to ``PCI_BUS_ID``.
        ValueError: If ``gpu_id`` is outside the range of detected GPUs.
    """
    try:
        # Execute nvidia-smi command to get one GPU name per line
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"],
            encoding="utf-8",
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing 'nvidia-smi' binary (FileNotFoundError) or a permission
        # error; treat it like a failed command so callers can fall back gracefully.
        logger.info("nvidia-smi failed with error: %s", e)
        return None

    # Drop blank lines so that empty output means "no GPUs" rather than a GPU named "".
    gpus = [line.strip() for line in output.splitlines() if line.strip()]
    if not gpus:
        return None

    # for multiple gpus, CUDA_DEVICE_ORDER=PCI_BUS_ID must be set to match nvidia-smi or else wrong
    # gpu is returned for gpu_id
    if len(gpus) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID":
        raise EnvironmentError("Multi-gpu environment must set `CUDA_DEVICE_ORDER=PCI_BUS_ID`.")

    if gpu_id >= len(gpus) or gpu_id < 0:
        raise ValueError(f"Passed gpu_id:{gpu_id} but there are {len(gpus)} detected Nvidia gpus.")

    return gpus[gpu_id]

def find_best_match(tags, query):
    """
    Fuzzy-match ``query`` against ``tags`` and return the best-matching TVM target tag.

    The best fuzzy match is accepted only if its ``arch`` equals the ``arch`` of a generic
    ``Target("cuda")`` and its similarity score is at least ``MATCH_THRESHOLD``.

    Args:
        tags: Candidate TVM target tag names.
        query (str): The string to match, e.g. a GPU model name.

    Returns:
        str: The best-matching tag, or ``"cuda"`` when there are no candidates, the arch
        check fails, or the score is below ``MATCH_THRESHOLD``.
    """
    MATCH_THRESHOLD = 25
    result = process.extractOne(query, tags)
    if result is None:
        # thefuzz returns None when there is nothing to score (e.g. empty ``tags``).
        logger.warning(TARGET_MISSING_ERROR)
        return "cuda"
    best_match, score = result

    # Reject a match whose arch differs from that of a plain "cuda" target.
    # NOTE(review): presumably Target("cuda").arch reflects the locally detected device — confirm.
    if Target(best_match).arch != Target("cuda").arch:
        logger.warning(TARGET_MISSING_ERROR)
        return "cuda"

    return best_match if score >= MATCH_THRESHOLD else "cuda"


def get_all_nvidia_targets() -> List[str]:
    """
    Return the names of all registered TVM target tags that contain "nvidia".

    Returns:
        List[str]: Matching tag names, or an empty list if ``list_tags()`` returns None.
    """
    # TVM documents that list_tags() may return None (e.g. in a runtime-only build);
    # fall back to an empty collection instead of failing while iterating.
    all_tags = list_tags() or {}
    return [tag for tag in all_tags if "nvidia" in tag]


def auto_detect_nvidia_target(gpu_id: int = 0) -> str:
    """
    Detect the local NVIDIA GPU via 'nvidia-smi' and map it to the closest TVM target tag.

    This function does not consult the ``TVM_TARGET`` environment variable.

    Args:
        gpu_id (int): Index of the GPU, as listed by 'nvidia-smi', to detect.

    Returns:
        str: The best-matching NVIDIA TVM target tag, or ``"cuda"`` if no GPU model
        could be determined or no suitable tag matched.

    Raises:
        EnvironmentError, ValueError: Propagated from ``get_gpu_model_from_nvidia_smi``.
    """
    # Get the current GPU model; without one there is nothing to match against.
    gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id)
    if not gpu_model:
        return "cuda"

    # Compat: remap OEM devices to their public model names so they match a TVM tag.
    gpu_model = NVIDIA_GPU_REMAP.get(gpu_model, gpu_model)

    # Reuse the shared tag filter rather than duplicating it here.
    return find_best_match(get_all_nvidia_targets(), gpu_model)