usiddiquee
hi
e1832f4
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
import os
import platform
import torch
from .. import __version__
from . import logger as LOGGER
from boxmot.utils import ROOT
def get_system_info():
    """Build the one-line banner describing the runtime environment.

    Returns:
        str: Text of the form
            ``Yolo Tracking v<package version> 🚀 Python-<python version> torch-<torch version>``.
    """
    python_version = platform.python_version()
    torch_version = torch.__version__
    return f"Yolo Tracking v{__version__} 🚀 Python-{python_version} torch-{torch_version}"
def parse_device(device):
    """Normalize a device specification into a bare lowercase string.

    The argument is converted with ``str()`` and lowercased, then the
    substrings ``"cuda:"`` and ``"none"`` are removed, followed by every
    parenthesis, square bracket, single quote and space.

    Args:
        device: Any object; only its ``str()`` form is used.

    Returns:
        str: The cleaned specification, e.g. ``"0"``, ``"0,1"``, ``"cpu"`` or ``""``.
    """
    spec = str(device).lower()
    # Substrings go first (in this order); single characters are stripped afterwards.
    for fragment in ("cuda:", "none"):
        spec = spec.replace(fragment, "")
    return spec.translate(str.maketrans("", "", "()[]' "))
def assert_cuda_available(device):
    """Raise if torch cannot provide the requested CUDA device(s).

    Args:
        device (str): Comma-separated CUDA device indices, e.g. ``"0"`` or
            ``"0,1,2,3"``.

    Raises:
        ValueError: If ``torch.cuda.is_available()`` is false, or if
            ``torch.cuda.device_count()`` is smaller than the number of
            requested devices.
    """
    # Count comma-separated entries, not characters, so a multi-digit index
    # such as "10" is one device rather than two. Empty entries (e.g. from a
    # trailing comma) are ignored.
    requested = len([entry for entry in device.split(",") if entry])
    if torch.cuda.is_available() and torch.cuda.device_count() >= requested:
        return

    # Only point at the install guide when torch sees no CUDA device at all.
    install = (
        "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no CUDA devices are seen by torch.\n"
        if torch.cuda.device_count() == 0
        else ""
    )
    raise ValueError(
        f"Invalid CUDA 'device={device}' requested. Use 'device=cpu' or pass valid CUDA device(s) if available, i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
        + f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
        + f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
        + f"\nos.environ['CUDA_VISIBLE_DEVICES']: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}\n{install}"
    )
def select_device(device="", batch=0):
    """Pick the torch device to run on and log a summary of the choice.

    The request is normalized with ``parse_device``. ``"cpu"`` and ``"mps"``
    select those backends; an empty request selects CUDA when
    ``torch.cuda.is_available()`` is true and CPU otherwise; anything else is
    treated as a comma-separated list of CUDA device indices.

    Side effect: ``os.environ["CUDA_VISIBLE_DEVICES"]`` is set to ``"-1"`` for
    CPU/MPS, or to the normalized device string for an explicit CUDA request.

    Args:
        device: Requested device; only its ``str()`` form is used, e.g.
            ``"cpu"``, ``"mps"``, ``"0"``, ``"cuda:0"`` or ``"0,1"``.
        batch (int): Batch size. When positive and more than one GPU is
            requested, it must be a multiple of the GPU count.

    Returns:
        torch.device: ``cpu``, ``mps``, or ``cuda:<first requested index>``.

    Raises:
        ValueError: If the requested CUDA devices are not available (raised
            by ``assert_cuda_available``) or ``batch`` is not a multiple of
            the number of requested GPUs.
    """
    s = get_system_info()
    device = parse_device(device)
    mps = device == "mps"
    # Explicit parentheses: `and` already binds tighter than `or`.
    cpu = device == "cpu" or (device == "" and not torch.cuda.is_available())

    if cpu or mps:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    elif device:
        # Set the variable before validating so the check sees the same value.
        os.environ["CUDA_VISIBLE_DEVICES"] = device
        assert_cuda_available(device)

    if not cpu and not mps and torch.cuda.is_available():
        devices = device.split(",") if device else ["0"]
        n = len(devices)
        if n > 1 and batch > 0 and batch % n != 0:
            raise ValueError(f"'batch={batch}' must be a multiple of GPU count {n}.")
        # NOTE(review): properties are looked up by enumeration index `i`,
        # while the label and the returned device use the requested index `d`.
        # Confirm this matches how CUDA_VISIBLE_DEVICES remaps indices here.
        gpu_lines = []
        for i, d in enumerate(devices):
            props = torch.cuda.get_device_properties(i)
            gpu_lines.append(f"CUDA:{d} ({props.name}, {props.total_memory / (1 << 20):.0f}MiB)")
        s += "\n" + "\n".join(gpu_lines)
        arg = "cuda:" + devices[0]
    elif mps:
        # Leading space keeps the backend name from fusing with the torch version.
        s += " MPS"
        arg = "mps"
    else:
        s += " CPU"
        arg = "cpu"

    LOGGER.info(s)
    return torch.device(arg)