| |
|
|
| import contextlib |
| import glob |
| import inspect |
| import math |
| import os |
| import platform |
| import re |
| import shutil |
| import subprocess |
| import time |
| from importlib import metadata |
| from pathlib import Path |
| from typing import Optional |
|
|
| import cv2 |
| import numpy as np |
| import requests |
| import torch |
| from matplotlib import font_manager |
|
|
| from doclayout_yolo.utils import ( |
| ASSETS, |
| AUTOINSTALL, |
| LINUX, |
| LOGGER, |
| ONLINE, |
| ROOT, |
| USER_CONFIG_DIR, |
| SimpleNamespace, |
| ThreadingLocked, |
| TryExcept, |
| clean_url, |
| colorstr, |
| downloads, |
| emojis, |
| is_colab, |
| is_docker, |
| is_github_action_running, |
| is_jupyter, |
| is_kaggle, |
| is_online, |
| is_pip_package, |
| url2file, |
| ) |
|
|
# Version string of the running interpreter as reported by the platform module, e.g. "3.10.12".
PYTHON_VERSION: str = platform.python_version()
|
|
|
|
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
    """
    Parse a requirements.txt file (or an installed package's metadata), ignoring comments.

    Lines starting with '#' are skipped and any text after '#' is discarded. Each remaining line is matched against
    ``<name><specifier>``, where the name must start with a letter or digit and may contain letters, digits, '.', '_'
    and '-'. Lines that do not start with such a name (e.g. pip options like ``-e .`` or ``-r other.txt``) are ignored.

    Args:
        file_path (Path): Path to the requirements.txt file. Ignored when ``package`` is given.
        package (str, optional): Installed Python package whose metadata requirements are parsed instead of the
            requirements.txt file, i.e. package='doclayout_yolo'. Requirements containing an ``extra == `` marker are
            excluded.

    Returns:
        (List[SimpleNamespace]): Parsed requirements, each with a ``name`` and a ``specifier`` attribute (the
            specifier is an empty string when none is given).

    Example:
        ```python
        from doclayout_yolo.utils.checks import parse_requirements

        parse_requirements(package='doclayout_yolo')
        ```
    """
    if package:
        # `requires` is None for distributions that declare no dependencies at all.
        requires = [x for x in (metadata.distribution(package).requires or []) if "extra == " not in x]
    else:
        requires = Path(file_path).read_text().splitlines()

    requirements = []
    for line in requires:
        # Drop inline comments; whole-line comments become empty and are skipped.
        line = line.split("#")[0].strip()
        if not line:
            continue
        # Dots are valid in distribution names (e.g. 'ruamel.yaml'), so they must be part of the name group.
        match = re.match(r"([a-zA-Z0-9][a-zA-Z0-9._-]*)\s*([<>!=~]+.*)?", line)
        if match:
            requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

    return requirements
|
|
|
|
def parse_version(version="0.0.0") -> tuple:
    """
    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. This
    function replaces deprecated 'pkg_resources.parse_version(v)'.

    Any PEP 440 local version label (the part after '+', e.g. '+cpu' or '+cu118') is discarded before parsing, so its
    digits are never mistaken for release numbers. At most the first three numeric components are kept.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'

    Returns:
        (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1). Returns (0, 0, 0)
            (after logging a warning) if ``version`` cannot be parsed, e.g. when it is not a string.
    """
    try:
        # '2.0+cu118' must parse as (2, 0), not (2, 0, 118).
        release = version.split("+")[0]
        return tuple(map(int, re.findall(r"\d+", release)[:3]))
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ failure for parse_version({version}), returning (0, 0, 0): {e}")
        return 0, 0, 0
|
|
|
|
def is_ascii(s) -> bool:
    """
    Tell whether the string form of ``s`` contains only 7-bit ASCII characters.

    Args:
        s (Any): Value to test; it is converted with ``str()`` first, so non-string inputs are accepted.

    Returns:
        (bool): True when every character has a code point below 128 (an empty string counts as ASCII).
    """
    return str(s).isascii()
|
|
|
|
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
    stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.

    Args:
        imgsz (int | str | List[int] | Tuple[int, ...]): Image size. A string may hold a single integer ('640') or a
            Python literal list/tuple/int ('640,480', '[640, 480]'); it is parsed with ``ast.literal_eval``.
        stride (int | torch.Tensor): Stride value; for a tensor its maximum is used.
        min_dim (int): Minimum number of dimensions. With 2, a single size is duplicated to [s, s]; with 1, a single
            size is returned as a plain int.
        max_dim (int): Maximum number of dimensions.
        floor (int): Minimum allowed value for image size.

    Returns:
        (int | List[int]): Updated image size.

    Raises:
        TypeError: If ``imgsz`` has an unsupported type.
        ValueError: If more than ``max_dim`` sizes are given and ``max_dim != 1``, or a string cannot be parsed.
    """
    # A tensor stride (one value per detection level) is reduced to its largest value.
    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

    # Normalize imgsz to a list of sizes.
    if isinstance(imgsz, int):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):
        if imgsz.isnumeric():
            imgsz = [int(imgsz)]
        else:
            import ast

            # literal_eval only accepts Python literals; eval() would execute arbitrary code from the string.
            parsed = ast.literal_eval(imgsz)
            # '640,480' yields a tuple and '(640)' an int; always continue with a list so later comparisons work.
            imgsz = [parsed] if isinstance(parsed, int) else list(parsed)
    else:
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Too many dimensions: collapse to the largest size when only one dimension is allowed, otherwise reject.
    if len(imgsz) > max_dim:
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"WARNING ⚠️ updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]

    # Round each size up to a multiple of the stride, but never below floor.
    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

    if sz != imgsz:
        LOGGER.warning(f"WARNING ⚠️ imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Expand a single size to [s, s] when two dims are required, or unwrap it to an int when one dim is required.
    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz

    return sz
|
|
|
|
def check_version(
    current: str = "0.0.0",
    required: str = "0.0.0",
    name: str = "version",
    hard: bool = False,
    verbose: bool = False,
    msg: str = "",
) -> bool:
    """
    Check current version against the required version or range.

    ``required`` is a comma-separated list of pip-style clauses using ``==``, ``!=``, ``>=``, ``<=``, ``>``, ``<`` or
    ``~=`` (compatible release); a clause without an operator is treated as ``>=``. Whitespace around clauses and
    operators is ignored, empty clauses are skipped, and versions of different lengths are compared as if padded with
    zeros (so '1.9' satisfies '==1.9.0').

    Args:
        current (str): Current version or package name to get version from.
        required (str): Required version or range (in pip-style format).
        name (str, optional): Name to be used in warning message.
        hard (bool, optional): If True, raise a ModuleNotFoundError if the requirement is not met.
        verbose (bool, optional): If True, print warning message if requirement is not met.
        msg (str, optional): Extra message to display if verbose.

    Returns:
        (bool): True if requirement is met, False otherwise. Also True when ``current`` or ``required`` is empty, and
            False when ``current`` names a package that is not installed (and ``hard`` is False).

    Raises:
        ModuleNotFoundError: If ``hard`` is True and the package is missing or the requirement is not met.
        ValueError: If a clause of ``required`` contains no version number.

    Example:
        ```python
        # Check if current version is exactly 22.04
        check_version(current='22.04', required='==22.04')

        # Check if current version is greater than or equal to 22.04
        check_version(current='22.10', required='22.04')  # assumes '>=' inequality if none passed

        # Check if current version is less than or equal to 22.04
        check_version(current='22.04', required='<=22.04')

        # Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
        check_version(current='21.10', required='>20.04,<22.04')
        ```
    """
    if not current:
        LOGGER.warning(f"WARNING ⚠️ invalid check_version({current}, {required}) requested, please check values.")
        return True
    elif not current[0].isdigit():
        # `current` is a package name: look up its installed version.
        try:
            name = current
            current = metadata.version(current)
        except metadata.PackageNotFoundError as e:
            if hard:
                raise ModuleNotFoundError(emojis(f"WARNING ⚠️ {current} package is required but not installed")) from e
            else:
                return False

    if not required:
        return True

    def _pad(a, b):
        """Right-pad two version tuples with zeros to the same length so they compare component-wise."""
        width = max(len(a), len(b))
        return a + (0,) * (width - len(a)), b + (0,) * (width - len(b))

    current_tuple = parse_version(current)
    result = True
    failed_op, failed_version = "", ""  # first unmet clause, reported in the warning
    for r in required.split(","):
        r = r.strip()
        if not r:
            continue  # tolerate leading, trailing or repeated commas
        match = re.match(r"([^0-9]*)([\d.]+)", r)
        if match is None:
            raise ValueError(f"invalid version requirement '{r}' in required='{required}'")
        # Strip so that '>= 1.0' or ', <2.0' still yield a recognised operator.
        op, version = match[1].strip(), match[2]
        required_tuple = parse_version(version)
        c, v = _pad(current_tuple, required_tuple)
        if op == "==":
            ok = c == v
        elif op == "!=":
            ok = c != v
        elif op in (">=", ""):
            ok = c >= v
        elif op == "<=":
            ok = c <= v
        elif op == ">":
            ok = c > v
        elif op == "<":
            ok = c < v
        elif op == "~=":
            # Compatible release: at least v, and the same release prefix (all but v's last given component).
            k = max(len(required_tuple) - 1, 1)
            ok = c >= v and c[:k] == v[:k]
        else:
            LOGGER.warning(f"WARNING ⚠️ unsupported version operator '{op}' in required='{required}', ignoring it.")
            ok = True
        if not ok and result:
            result, failed_op, failed_version = False, op, version

    if not result:
        warning = (
            f"WARNING ⚠️ {name}{failed_op}{failed_version} is required, "
            f"but {name}=={current} is currently installed {msg}"
        )
        if hard:
            raise ModuleNotFoundError(emojis(warning))
        if verbose:
            LOGGER.warning(warning)
    return result
|
|
|
|
def check_latest_pypi_version(package_name="doclayout_yolo"):
    """
    Ask the PyPI JSON API for the newest published version of a package, without downloading or installing it.

    Parameters:
        package_name (str): Name of the package to look up.

    Returns:
        (str | None): The latest version string, or None if the request fails for any reason or PyPI does not answer
            with HTTP 200. The request uses a 3 second timeout and disables urllib3 warnings first.
    """
    with contextlib.suppress(Exception):
        requests.packages.urllib3.disable_warnings()
        reply = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if reply.status_code != 200:
            return None
        return reply.json()["info"]["version"]
    return None
|
|
|
|
def check_pip_update_available():
    """
    Log a notice when PyPI has a newer doclayout_yolo release than the installed one.

    The check only runs when online and installed as a pip package; any error during the check is suppressed.

    Returns:
        (bool): True if a newer version was found (and the notice logged), False otherwise.
    """
    if not (ONLINE and is_pip_package()):
        return False

    with contextlib.suppress(Exception):
        from doclayout_yolo import __version__

        newest = check_latest_pypi_version()
        if check_version(__version__, f"<{newest}"):
            LOGGER.info(
                f"New https://pypi.org/project/doclayout_yolo/{newest} available 😃 "
                f"Update with 'pip install -U doclayout_yolo'"
            )
            return True
    return False
|
|
|
|
@ThreadingLocked()
def check_font(font="Arial.ttf"):
    """
    Locate a font file locally, or download it to the user's configuration directory.

    The lookup tries, in order:
      1. ``USER_CONFIG_DIR / <file name of font>``;
      2. the first path from matplotlib's ``font_manager.findSystemFonts()`` that contains ``font`` as a substring;
      3. a download from the project assets URL, if that URL is reachable.

    Args:
        font (str): Path or name of font.

    Returns:
        (Path | str | None): The cached/downloaded file path, a matching system font path, or None when the font is
            not available locally and the download URL cannot be reached.
    """
    font_name = Path(font).name

    cached = USER_CONFIG_DIR / font_name
    if cached.exists():
        return cached

    system_fonts = [path for path in font_manager.findSystemFonts() if font in path]
    if any(system_fonts):
        return system_fonts[0]

    # NOTE(review): 'doclayout_yolo.com' contains an underscore, which is not valid in a DNS host name; this looks like
    # a rebranding artifact and the download may never succeed — confirm the intended assets URL.
    url = f"https://doclayout_yolo.com/assets/{font_name}"
    if not downloads.is_url(url, check=True):
        return None
    downloads.safe_download(url=url, file=cached)
    return cached
|
|
|
|
def check_python(minimum: str = "3.8.0") -> bool:
    """
    Verify that the running interpreter satisfies a minimum Python version.

    Args:
        minimum (str): Minimum required version; a bare version is compared with '>=' by check_version.

    Returns:
        (bool): True when the interpreter meets the requirement.

    Raises:
        ModuleNotFoundError: Raised by check_version (called with hard=True) when the interpreter is too old.
    """
    satisfied = check_version(PYTHON_VERSION, minimum, name="Python ", hard=True)
    return satisfied
|
|
|
|
@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet YOLOv8 requirements and attempt to auto-update if needed.

    Args:
        requirements (Union[Path, str, List[str]]): A ``Path`` to a requirements.txt file, a single package
            requirement as a string, or a list of package requirements as strings. A plain ``str`` is always treated
            as a package requirement, never as a file path.
        exclude (Tuple[str]): Tuple of package names to exclude from checking (applied when reading a file).
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional commands to pass to the pip install command when auto-updating.

    Returns:
        (bool): True if every requirement is satisfied or was just auto-installed, False otherwise.

    Example:
        ```python
        from pathlib import Path

        from doclayout_yolo.utils.checks import check_requirements

        # Check a requirements.txt file (must be passed as a Path)
        check_requirements(Path('path/to/requirements.txt'))

        # Check a single package
        check_requirements('doclayout_yolo>=8.0.0')

        # Check multiple packages
        check_requirements(['numpy', 'doclayout_yolo>=8.0.0'])
        ```
    """
    prefix = colorstr("red", "bold", "requirements:")
    check_python()
    check_torchvision()
    if isinstance(requirements, Path):
        file = requirements.resolve()
        assert file.exists(), f"{prefix} {file} not found, check failed."
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    pkgs = []
    for r in requirements:
        # Reduce URL/git requirements ('git+https://github.com/org/pkg.git') to their last path segment.
        r_stripped = r.split("/")[-1].replace(".git", "")
        match = re.match(r"([a-zA-Z0-9][a-zA-Z0-9._-]*)\s*([<>!=~]+.*)?", r_stripped)
        if match is None:
            LOGGER.warning(f"{prefix} ⚠️ unable to parse requirement '{r}', skipping.")
            continue
        name, required = match[1], match[2].strip() if match[2] else ""
        # Explicit branch instead of `assert`, which would be stripped under `python -O`.
        try:
            satisfied = check_version(metadata.version(name), required)
        except metadata.PackageNotFoundError:
            satisfied = False
        if not satisfied:
            pkgs.append(r)

    s = " ".join(f'"{x}"' for x in pkgs)
    if s:
        if install and AUTOINSTALL:
            n = len(pkgs)
            LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
            try:
                t = time.time()
                if not is_online():
                    raise ConnectionError("AutoUpdate skipped (offline)")
                # NOTE(review): this runs whichever `pip` is first on PATH through the shell, which may not belong to
                # the current interpreter; `cmds` is passed to the shell verbatim.
                LOGGER.info(subprocess.check_output(f"pip install --no-cache {s} {cmds}", shell=True).decode())
                dt = time.time() - t
                LOGGER.info(
                    f"{prefix} AutoUpdate success ✅ {dt:.1f}s, installed {n} package{'s' * (n > 1)}: {pkgs}\n"
                    f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
                )
            except Exception as e:
                LOGGER.warning(f"{prefix} ❌ {e}")
                return False
        else:
            return False

    return True
|
|
|
|
def check_torchvision():
    """
    Checks the installed versions of PyTorch and Torchvision to ensure they're compatible.

    The torch and torchvision versions are reduced to ``major.minor`` (ignoring any '+' local label) and looked up in a
    compatibility table based on https://github.com/pytorch/vision#installation, whose keys are PyTorch versions and
    whose values are lists of compatible Torchvision versions. A warning is logged on mismatch. Torch versions missing
    from the table are not checked, and nothing is checked when torchvision is not installed.
    """
    try:
        import torchvision
    except ImportError:
        # Without torchvision there is nothing to compare; do not let callers such as check_requirements fail here.
        return

    compatibility_table = {
        "2.5": ["0.20"],
        "2.4": ["0.19"],
        "2.3": ["0.18"],
        "2.2": ["0.17"],
        "2.1": ["0.16"],
        "2.0": ["0.15"],
        "1.13": ["0.14"],
        "1.12": ["0.13"],
    }

    # e.g. '2.0.1+cu118' -> '2.0'
    v_torch = ".".join(torch.__version__.split("+")[0].split(".")[:2])
    v_torchvision = ".".join(torchvision.__version__.split("+")[0].split(".")[:2])

    compatible_versions = compatibility_table.get(v_torch)
    if compatible_versions and v_torchvision not in compatible_versions:
        LOGGER.warning(
            f"WARNING ⚠️ torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
            f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
            "'pip install -U torch torchvision' to update both.\n"
            "For a full compatibility table see https://github.com/pytorch/vision#installation"
        )
|
|
|
|
def check_suffix(file="yolov8n.pt", suffix=".pt", msg=""):
    """
    Assert that one or more files carry an accepted suffix.

    Args:
        file (str | Path | list | tuple): File or files to check.
        suffix (str | tuple): Accepted suffix or suffixes; each file's suffix is lower-cased before comparison.
        msg (str): Prefix for the assertion message.

    Raises:
        AssertionError: If a file has a non-empty suffix that is not accepted. Files without a suffix always pass,
            and nothing is checked when ``file`` or ``suffix`` is falsy.
    """
    if not (file and suffix):
        return
    accepted = (suffix,) if isinstance(suffix, str) else suffix
    candidates = file if isinstance(file, (list, tuple)) else [file]
    for candidate in candidates:
        ext = Path(candidate).suffix.lower().strip()
        if ext:
            assert ext in accepted, f"{msg}{candidate} acceptable suffix is {accepted}, not {ext}"
|
|
|
|
def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Map legacy YOLOv5/YOLOv3 file names onto their 'u' counterparts.

    Only strings containing 'yolov3' or 'yolov5' are touched:
      - names containing 'u.yaml' lose the 'u' (e.g. 'yolov5su.yaml' -> 'yolov5s.yaml');
      - '.pt' names gain a 'u' (e.g. 'yolov5s.pt' -> 'yolov5su.pt', 'yolov5s6.pt' -> 'yolov5s6u.pt',
        'yolov3-tiny.pt' -> 'yolov3-tinyu.pt'), but only when the letter 'u' appears nowhere in ``file``.
        NOTE(review): that check covers the whole string, so e.g. a '/home/user/' directory also suppresses the rename.

    Args:
        file (str): File name or path.
        verbose (bool): Log a tip when a '.pt' name was rewritten.

    Returns:
        (str): The possibly rewritten file name.
    """
    if "yolov3" not in file and "yolov5" not in file:
        return file
    if "u.yaml" in file:
        return file.replace("u.yaml", ".yaml")
    if ".pt" not in file or "u" in file:
        return file

    renamed = file
    # Applied in sequence, each on the result of the previous substitution.
    for pattern in (r"(.*yolov5([nsmlx]))\.pt", r"(.*yolov5([nsmlx])6)\.pt", r"(.*yolov3(|-tiny|-spp))\.pt"):
        renamed = re.sub(pattern, "\\1u.pt", renamed)

    if renamed != file and verbose:
        LOGGER.info(
            f"PRO TIP 💡 Replace 'model={file}' with new 'model={renamed}'.\nYOLOv5 'u' models are "
            f"trained with https://github.com/doclayout_yolo/doclayout_yolo and feature improved performance vs "
            f"standard YOLOv5 models trained with https://github.com/doclayout_yolo/yolov5.\n"
        )
    return renamed
|
|
|
|
def check_model_file_from_stem(model="yolov8n"):
    """
    Turn a bare model stem into a '.pt' file name when it names a known GitHub release asset.

    Args:
        model (str | Path): Model name or path.

    Returns:
        (Path | str): ``Path(model).with_suffix('.pt')`` when ``model`` is non-empty, has no suffix and its stem is in
            ``downloads.GITHUB_ASSETS_STEMS``; otherwise ``model`` unchanged.
    """
    if not model:
        return model
    candidate = Path(model)
    if candidate.suffix or candidate.stem not in downloads.GITHUB_ASSETS_STEMS:
        return model
    return candidate.with_suffix(".pt")
|
|
|
|
def check_file(file, suffix="", download=True, hard=True):
    """
    Resolve ``file`` to a usable path, downloading or searching for it when necessary.

    Resolution order:
      1. an empty string, an existing local path (containing no '://') or a 'grpc://' address is returned as is;
      2. an http(s)/rtsp/rtmp/tcp URL (when ``download``) is mapped to a local name with url2file and downloaded
         unless that local file already exists;
      3. otherwise the name is searched recursively under ROOT, then directly inside ROOT.parent.

    Args:
        file (str | Path): File name, path or URL. Legacy YOLOv5 names are first rewritten by check_yolov5u_filename.
        suffix (str | tuple): Accepted suffix(es), validated with check_suffix.
        download (bool): Allow downloading URLs.
        hard (bool): Raise when the search finds zero or several matches.

    Returns:
        (str): Resolved path; when nothing is found and ``hard`` is False, the (normalized) input string.

    Raises:
        FileNotFoundError: If ``hard`` and the search finds no file or more than one file.
    """
    check_suffix(file, suffix)
    file = check_yolov5u_filename(str(file).strip())

    if not file:
        return file
    if "://" not in file and Path(file).exists():
        return file
    if file.lower().startswith("grpc://"):
        return file

    if download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
        url = file
        local = url2file(url)
        if Path(local).exists():
            LOGGER.info(f"Found {clean_url(url)} locally at {local}")
        else:
            downloads.safe_download(url=url, file=local, unzip=False)
        return local

    matches = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))
    if hard:
        if not matches:
            raise FileNotFoundError(f"'{file}' does not exist")
        if len(matches) > 1:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {matches}")
    return matches[0] if matches else file
|
|
|
|
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """
    Resolve a YAML file with check_file, requiring a '.yaml'/'.yml' suffix by default.

    Args:
        file (str | Path): YAML file name, path or URL.
        suffix (tuple): Accepted suffixes.
        hard (bool): Passed to check_file; raise if the file cannot be uniquely found.

    Returns:
        (str): Path returned by check_file.
    """
    return check_file(file, suffix=suffix, hard=hard)
|
|
|
|
def check_is_path_safe(basedir, path):
    """
    Guard against path traversal by confirming a path resolves to an existing file inside a base directory.

    Both arguments are resolved (symlinks and '..' are collapsed) before comparing path components.

    Args:
        basedir (Path | str): The intended directory.
        path (Path | str): The path to check.

    Returns:
        (bool): True if ``path`` is an existing file whose resolved location starts with the resolved ``basedir``
            components, False otherwise.
    """
    base_parts = Path(basedir).resolve().parts
    target = Path(path).resolve()
    if not target.is_file():
        return False
    return target.parts[: len(base_parts)] == base_parts
|
|
|
|
def check_imshow(warn=False):
    """
    Report whether this environment can open an OpenCV display window.

    On Linux, a DISPLAY environment variable is required and Docker, Colab and Kaggle are treated as unable to
    display. Then an 8x8 black test image is shown and the window destroyed.

    Args:
        warn (bool): Log a warning with the failure reason when display is not supported.

    Returns:
        (bool): True if the test window could be shown, False otherwise.
    """
    try:
        if LINUX:
            assert "DISPLAY" in os.environ and not any(probe() for probe in (is_docker, is_colab, is_kaggle))
        blank = np.zeros((8, 8, 3), dtype=np.uint8)
        cv2.imshow("test", blank)
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        if warn:
            LOGGER.warning(f"WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
        return False
|
|
|
|
def check_yolo(verbose=True, device=""):
    """
    Prepare the runtime environment and log a short software/hardware summary.

    In Jupyter, wandb is uninstalled if it is installed, and on Colab the 'sample_data' directory is removed. When
    ``verbose``, the CPU count, total RAM and disk usage of '/' are included and IPython output is cleared if possible.
    Finally select_device is called for ``device`` and a 'Setup complete' line is logged.

    Args:
        verbose (bool): Include the CPU/RAM/disk summary.
        device (str): Device string passed to select_device.
    """
    import psutil

    from doclayout_yolo.utils.torch_utils import select_device

    if is_jupyter():
        # NOTE(review): presumably removes wandb to avoid interactive login prompts in notebooks — confirm.
        if check_requirements("wandb", install=False):
            os.system("pip uninstall -y wandb")
        if is_colab():
            shutil.rmtree("sample_data", ignore_errors=True)

    summary = ""
    if verbose:
        gib = 1 << 30  # bytes per GiB
        ram_gib = psutil.virtual_memory().total / gib
        total, _, free = shutil.disk_usage("/")
        summary = (
            f"({os.cpu_count()} CPUs, {ram_gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
        )
        with contextlib.suppress(Exception):
            from IPython import display

            display.clear_output()

    select_device(device=device, newline=False)
    LOGGER.info(f"Setup complete ✅ {summary}")
|
|
|
|
def collect_system_info():
    """
    Collect and log relevant system information.

    Logs the OS, environment, Python version, install type, total RAM, CPU and CUDA version (after running
    check_yolo, which logs its own setup summary). It then logs the installed version of each doclayout_yolo
    requirement, marked ✅ when met and ❌ when unmet or not installed. On GitHub Actions, selected runner environment
    variables are logged as well.
    """
    import psutil

    from doclayout_yolo.utils import ENVIRONMENT, is_git_dir
    from doclayout_yolo.utils.torch_utils import get_cpu_info

    ram_info = psutil.virtual_memory().total / (1024**3)  # bytes -> GB
    check_yolo()
    LOGGER.info(
        f"\n{'OS':<20}{platform.platform()}\n"
        f"{'Environment':<20}{ENVIRONMENT}\n"
        f"{'Python':<20}{PYTHON_VERSION}\n"
        f"{'Install':<20}{'git' if is_git_dir() else 'pip' if is_pip_package() else 'other'}\n"
        f"{'RAM':<20}{ram_info:.2f} GB\n"
        f"{'CPU':<20}{get_cpu_info()}\n"
        f"{'CUDA':<20}{torch.version.cuda if torch and torch.cuda.is_available() else None}\n"
    )

    for r in parse_requirements(package="doclayout_yolo"):
        try:
            current = metadata.version(r.name)
            # hard=False: an unmet requirement must be reported as ❌, not raise ModuleNotFoundError and abort the report.
            is_met = "✅ " if check_version(current, str(r.specifier), hard=False) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        LOGGER.info(f"{r.name:<20}{is_met}{current}{r.specifier}")

    if is_github_action_running():
        LOGGER.info(
            f"\nRUNNER_OS: {os.getenv('RUNNER_OS')}\n"
            f"GITHUB_EVENT_NAME: {os.getenv('GITHUB_EVENT_NAME')}\n"
            f"GITHUB_WORKFLOW: {os.getenv('GITHUB_WORKFLOW')}\n"
            f"GITHUB_ACTOR: {os.getenv('GITHUB_ACTOR')}\n"
            f"GITHUB_REPOSITORY: {os.getenv('GITHUB_REPOSITORY')}\n"
            f"GITHUB_REPOSITORY_OWNER: {os.getenv('GITHUB_REPOSITORY_OWNER')}\n"
        )
|
|
|
|
def check_amp(model):
    """
    This function checks the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLOv8 model. If the checks
    fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP results, so AMP will
    be disabled during training.

    The check runs a reference 'yolov8n.pt' model on the bundled 'bus.jpg' asset twice, in FP32 and under CUDA
    autocast, and compares the resulting boxes (same shape, values within atol=0.5).

    Args:
        model (nn.Module): A YOLOv8 model instance. Only its parameters' device is read.

    Example:
        ```python
        from doclayout_yolo import YOLO
        from doclayout_yolo.utils.checks import check_amp

        model = YOLO('yolov8n.pt').model.cuda()
        check_amp(model)
        ```

    Returns:
        (bool): False if the model is on CPU/MPS or the FP32 and AMP results differ. True otherwise, including when
            the check is skipped because the reference model cannot be downloaded or loaded.
    """
    device = next(model.parameters()).device
    if device.type in ("cpu", "mps"):
        return False  # AMP is not checked or used on CPU/MPS

    def amp_allclose(m, im):
        """Return True if FP32 and CUDA-autocast inference on ``im`` give boxes of equal shape within atol=0.5."""
        a = m(im, device=device, verbose=False)[0].boxes.data
        # torch.autocast(device_type="cuda") replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", enabled=True):
            b = m(im, device=device, verbose=False)[0].boxes.data
        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)

    im = ASSETS / "bus.jpg"
    prefix = colorstr("AMP: ")
    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...")
    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
    try:
        from doclayout_yolo import YOLO

        assert amp_allclose(YOLO("yolov8n.pt"), im)
        LOGGER.info(f"{prefix}checks passed ✅")
    except ConnectionError:
        LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}")
    except (AttributeError, ModuleNotFoundError):
        LOGGER.warning(
            f"{prefix}checks skipped ⚠️. "
            f"Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}"
        )
    except AssertionError:
        LOGGER.warning(
            f"{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to "
            f"NaN losses or zero-mAP results, so AMP will be disabled during training."
        )
        return False
    return True
|
|
|
|
def git_describe(path=ROOT):
    """
    Return the human-readable git description of the repository at ``path``, i.e. v5.0-5-g3e25f1e.

    See https://git-scm.com/docs/git-describe (called with --tags --long --always).

    Args:
        path (Path | str): Directory inside the git work tree.

    Returns:
        (str): The description, or '' if git is unavailable or ``path`` is not inside a git repository.
    """
    with contextlib.suppress(Exception):
        # An argument list (no shell) passes paths with spaces or shell metacharacters verbatim; stderr is silenced so
        # non-repositories do not print 'fatal: ...' noise.
        output = subprocess.check_output(
            ["git", "-C", str(path), "describe", "--tags", "--long", "--always"], stderr=subprocess.DEVNULL
        )
        return output.decode().strip()
    return ""
|
|
|
|
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    """
    Log ``key=value`` pairs for function arguments, optionally prefixed by the caller's file and function name.

    Args:
        args (dict, optional): Arguments to log. If None, the calling frame's own arguments are collected with inspect.
        show_file (bool): Prefix with the caller's file: relative to ROOT without suffix when possible, else its stem.
        show_func (bool): Prefix with the caller's function name.
    """

    def strip_auth(v):
        """Pass http(s) strings longer than 100 characters through clean_url (meant to drop authentication info)."""
        if isinstance(v, str) and v.startswith("http") and len(v) > 100:
            return clean_url(v)
        return v

    caller = inspect.currentframe().f_back
    file, _, func, _, _ = inspect.getframeinfo(caller)
    if args is None:
        arg_names, _, _, frame_locals = inspect.getargvalues(caller)
        args = {k: v for k, v in frame_locals.items() if k in arg_names}

    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem

    header = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    colored_header = colorstr(header)
    body = ", ".join(f"{k}={strip_auth(v)}" for k, v in args.items())
    LOGGER.info(colored_header + body)
|
|
|
|
def cuda_device_count() -> int:
    """
    Get the number of NVIDIA GPUs available in the environment, as reported by ``nvidia-smi``.

    Returns:
        (int): The integer on the first line of ``nvidia-smi --query-gpu=count`` output, or 0 when nvidia-smi is
            missing or cannot be executed, exits with an error, or prints something that is not an integer.
    """
    try:
        output = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
        )
    except (subprocess.CalledProcessError, OSError):
        # OSError covers FileNotFoundError (nvidia-smi not installed) and PermissionError (not executable).
        return 0

    # nvidia-smi may print one line per GPU; only the first line is parsed.
    first_line = output.strip().split("\n")[0]
    try:
        return int(first_line)
    except ValueError:
        return 0
|
|
|
|
def cuda_is_available() -> bool:
    """
    Report whether nvidia-smi sees at least one NVIDIA GPU.

    Returns:
        (bool): True when cuda_device_count() reports one or more GPUs, False otherwise.
    """
    return cuda_device_count() >= 1
|
|
|
|
| |
# True when the running interpreter is a Python 3.12.x release (prefix check on PYTHON_VERSION).
IS_PYTHON_3_12: bool = PYTHON_VERSION.startswith("3.12")
|
|