File size: 1,597 Bytes
7344bef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from __future__ import annotations

import os
import sys

from collections.abc import MutableSequence


def _arg_name_to_option(arg_name: str) -> str:
    """Normalize *arg_name* into a long command-line option string.

    A falsy or whitespace-only name yields ``""``. A name that already starts
    with ``--`` is returned as-is (after stripping surrounding whitespace);
    any other name gets ``--`` prepended.
    """
    cleaned = str(arg_name or "").strip()
    if cleaned and not cleaned.startswith("--"):
        cleaned = "--" + cleaned
    return cleaned


def _cuda_visible_device(device: str) -> str:
    """Return the bare device index from a device spec, or ``""`` if invalid.

    Accepts ``"N"`` or ``"cuda:N"`` (case-insensitive, surrounding whitespace
    ignored), where N is a run of ASCII digits. Anything else yields ``""``.
    That includes ``"cpu"``, ``"cuda"`` without an index, comma-separated
    lists and ``None``.
    """
    device = str(device or "").strip().lower()
    if device.startswith("cuda:"):
        device = device.split(":", 1)[1]
    # str.isdigit() alone also accepts characters such as "\u00b2" or
    # "\u2460", which are not valid device ordinals (int() rejects them).
    # Require plain ASCII digits.
    return device if device.isascii() and device.isdigit() else ""


def _rewrite_arg_value(argv: MutableSequence[str], option: str, value: str) -> None:
    """Replace, in place, the value of the first occurrence of *option* in *argv*.

    Two forms are recognised: ``option value`` (two tokens) and
    ``option=value`` (one token). A bare *option* in the final position has
    no value token and is skipped. Only the first match is rewritten, and
    *argv* is left untouched when nothing matches.
    """
    prefix = f"{option}="
    last = len(argv) - 1
    for pos, token in enumerate(argv):
        if token == option and pos < last:
            argv[pos + 1] = value
            break
        if str(token).startswith(prefix):
            argv[pos] = f"{prefix}{value}"
            break


def _resolve_visible_device(index: str, current_mask: str | None) -> str:
    """Translate a requested CUDA device *index* into a CUDA_VISIBLE_DEVICES entry.

    *current_mask* is the pre-existing ``CUDA_VISIBLE_DEVICES`` value, or
    ``None`` when it is unset.

    - Unset or blank mask: *index* is returned unchanged.
    - Otherwise: *index* counts positions within that comma-separated list,
      and the entry at that position is returned.
    - ``""`` is returned when *index* is not an integer, falls outside the
      list, or selects an empty entry.
    """
    mask = (current_mask or "").strip()
    if not mask:
        return index
    entries = [entry.strip() for entry in mask.split(",")]
    try:
        position = int(index)
    except ValueError:
        return ""
    return entries[position] if 0 <= position < len(entries) else ""


def set_default_cuda_device_from_arg(arg_name: str, default_device: str = "cuda:0") -> bool:
    """Pin the process to the CUDA device named by a command-line option.

    Scans ``sys.argv[1:]`` for the first ``--<arg_name> VALUE`` or
    ``--<arg_name>=VALUE`` token. When VALUE names a single device index
    (``"3"`` or ``"cuda:3"``), two things happen:

    - ``CUDA_VISIBLE_DEVICES`` is set so that only that device is exposed.
    - The option's value in ``sys.argv`` is rewritten (in place) to
      *default_device*.

    If ``CUDA_VISIBLE_DEVICES`` already holds a non-empty list, the requested
    index is interpreted relative to that list, matching how CUDA numbers
    devices under a mask. The selected entry is kept rather than overwriting
    the mask with the raw index, which could expose a device outside the
    original mask.

    Note: CUDA reads ``CUDA_VISIBLE_DEVICES`` when its runtime initializes.
    Presumably this must therefore run before anything in the process
    initializes CUDA.

    Args:
        arg_name: Option name, with or without the leading ``--``.
        default_device: Value written back into ``sys.argv`` in place of the
            original device spec.

    Returns:
        ``True`` if the environment and ``sys.argv`` were updated. ``False``
        if nothing was changed, which happens when:

        - the option is empty or absent;
        - its value is not a single device index;
        - the index falls outside an existing ``CUDA_VISIBLE_DEVICES`` list.
    """
    option = _arg_name_to_option(arg_name)
    if not option:
        return False

    argv = sys.argv
    prefix = f"{option}="
    requested = ""
    # Skip argv[0] (the program name); stop at the first occurrence.
    for index, arg in enumerate(argv[1:], start=1):
        if arg == option and index + 1 < len(argv):
            requested = _cuda_visible_device(argv[index + 1])
            break
        if str(arg).startswith(prefix):
            requested = _cuda_visible_device(str(arg)[len(prefix):])
            break
    if not requested:
        return False

    visible_device = _resolve_visible_device(
        requested, os.environ.get("CUDA_VISIBLE_DEVICES")
    )
    if not visible_device:
        return False
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_device
    # Only one device is now visible, so downstream code must address it
    # by the default index.
    _rewrite_arg_value(argv, option, default_device)
    return True