Spaces:
Build error
Build error
Upload 3 files
Browse files- general.py +1135 -0
- plots.py +570 -0
- torch_utils.py +529 -0
general.py
ADDED
|
@@ -0,0 +1,1135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import glob
|
| 3 |
+
import inspect
|
| 4 |
+
import logging
|
| 5 |
+
import logging.config
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import platform
|
| 9 |
+
import random
|
| 10 |
+
import re
|
| 11 |
+
import signal
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
import urllib
|
| 15 |
+
from copy import deepcopy
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from itertools import repeat
|
| 18 |
+
from multiprocessing.pool import ThreadPool
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from subprocess import check_output
|
| 21 |
+
from tarfile import is_tarfile
|
| 22 |
+
from typing import Optional
|
| 23 |
+
from zipfile import ZipFile, is_zipfile
|
| 24 |
+
|
| 25 |
+
import cv2
|
| 26 |
+
import IPython
|
| 27 |
+
import numpy as np
|
| 28 |
+
import pandas as pd
|
| 29 |
+
import pkg_resources as pkg
|
| 30 |
+
import torch
|
| 31 |
+
import torchvision
|
| 32 |
+
import yaml
|
| 33 |
+
|
| 34 |
+
from utils import TryExcept, emojis
|
| 35 |
+
from utils.downloads import gsutil_getsize
|
| 36 |
+
from utils.metrics import box_iou, fitness
|
| 37 |
+
|
| 38 |
+
FILE = Path(__file__).resolve()
|
| 39 |
+
ROOT = FILE.parents[1] # YOLO root directory
|
| 40 |
+
RANK = int(os.getenv('RANK', -1))
|
| 41 |
+
|
| 42 |
+
# Settings
|
| 43 |
+
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads
|
| 44 |
+
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets')) # global datasets directory
|
| 45 |
+
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode
|
| 46 |
+
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode
|
| 47 |
+
TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}' # tqdm bar format
|
| 48 |
+
FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf
|
| 49 |
+
|
| 50 |
+
torch.set_printoptions(linewidth=320, precision=5, profile='long')
|
| 51 |
+
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5
|
| 52 |
+
pd.options.display.max_columns = 10
|
| 53 |
+
cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
|
| 54 |
+
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
|
| 55 |
+
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def is_ascii(s=''):
|
| 59 |
+
# Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
|
| 60 |
+
s = str(s) # convert list, tuple, None, etc. to str
|
| 61 |
+
return len(s.encode().decode('ascii', 'ignore')) == len(s)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def is_chinese(s='人工智能'):
|
| 65 |
+
# Is string composed of any Chinese characters?
|
| 66 |
+
return bool(re.search('[\u4e00-\u9fff]', str(s)))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def is_colab():
|
| 70 |
+
# Is environment a Google Colab instance?
|
| 71 |
+
return 'google.colab' in sys.modules
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def is_notebook():
|
| 75 |
+
# Is environment a Jupyter notebook? Verified on Colab, Jupyterlab, Kaggle, Paperspace
|
| 76 |
+
ipython_type = str(type(IPython.get_ipython()))
|
| 77 |
+
return 'colab' in ipython_type or 'zmqshell' in ipython_type
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def is_kaggle():
|
| 81 |
+
# Is environment a Kaggle Notebook?
|
| 82 |
+
return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def is_docker() -> bool:
|
| 86 |
+
"""Check if the process runs inside a docker container."""
|
| 87 |
+
if Path("/.dockerenv").exists():
|
| 88 |
+
return True
|
| 89 |
+
try: # check if docker is in control groups
|
| 90 |
+
with open("/proc/self/cgroup") as file:
|
| 91 |
+
return any("docker" in line for line in file)
|
| 92 |
+
except OSError:
|
| 93 |
+
return False
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def is_writeable(dir, test=False):
|
| 97 |
+
# Return True if directory has write permissions, test opening a file with write permissions if test=True
|
| 98 |
+
if not test:
|
| 99 |
+
return os.access(dir, os.W_OK) # possible issues on Windows
|
| 100 |
+
file = Path(dir) / 'tmp.txt'
|
| 101 |
+
try:
|
| 102 |
+
with open(file, 'w'): # open file with write permissions
|
| 103 |
+
pass
|
| 104 |
+
file.unlink() # remove file
|
| 105 |
+
return True
|
| 106 |
+
except OSError:
|
| 107 |
+
return False
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
LOGGING_NAME = "yolov5"
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def set_logging(name=LOGGING_NAME, verbose=True):
|
| 114 |
+
# sets up logging for the given name
|
| 115 |
+
rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
|
| 116 |
+
level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
|
| 117 |
+
logging.config.dictConfig({
|
| 118 |
+
"version": 1,
|
| 119 |
+
"disable_existing_loggers": False,
|
| 120 |
+
"formatters": {
|
| 121 |
+
name: {
|
| 122 |
+
"format": "%(message)s"}},
|
| 123 |
+
"handlers": {
|
| 124 |
+
name: {
|
| 125 |
+
"class": "logging.StreamHandler",
|
| 126 |
+
"formatter": name,
|
| 127 |
+
"level": level,}},
|
| 128 |
+
"loggers": {
|
| 129 |
+
name: {
|
| 130 |
+
"level": level,
|
| 131 |
+
"handlers": [name],
|
| 132 |
+
"propagate": False,}}})
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
set_logging(LOGGING_NAME) # run before defining LOGGER
|
| 136 |
+
LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
|
| 137 |
+
if platform.system() == 'Windows':
|
| 138 |
+
for fn in LOGGER.info, LOGGER.warning:
|
| 139 |
+
setattr(LOGGER, fn.__name__, lambda x: fn(emojis(x))) # emoji safe logging
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
|
| 143 |
+
# Return path of user configuration directory. Prefer environment variable if exists. Make dir if required.
|
| 144 |
+
env = os.getenv(env_var)
|
| 145 |
+
if env:
|
| 146 |
+
path = Path(env) # use environment variable
|
| 147 |
+
else:
|
| 148 |
+
cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs
|
| 149 |
+
path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir
|
| 150 |
+
path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable
|
| 151 |
+
path.mkdir(exist_ok=True) # make if required
|
| 152 |
+
return path
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
CONFIG_DIR = user_config_dir() # Ultralytics settings dir
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class Profile(contextlib.ContextDecorator):
|
| 159 |
+
# YOLO Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
|
| 160 |
+
def __init__(self, t=0.0):
|
| 161 |
+
self.t = t
|
| 162 |
+
self.cuda = torch.cuda.is_available()
|
| 163 |
+
|
| 164 |
+
def __enter__(self):
|
| 165 |
+
self.start = self.time()
|
| 166 |
+
return self
|
| 167 |
+
|
| 168 |
+
def __exit__(self, type, value, traceback):
|
| 169 |
+
self.dt = self.time() - self.start # delta-time
|
| 170 |
+
self.t += self.dt # accumulate dt
|
| 171 |
+
|
| 172 |
+
def time(self):
|
| 173 |
+
if self.cuda:
|
| 174 |
+
torch.cuda.synchronize()
|
| 175 |
+
return time.time()
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class Timeout(contextlib.ContextDecorator):
|
| 179 |
+
# YOLO Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
|
| 180 |
+
def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
|
| 181 |
+
self.seconds = int(seconds)
|
| 182 |
+
self.timeout_message = timeout_msg
|
| 183 |
+
self.suppress = bool(suppress_timeout_errors)
|
| 184 |
+
|
| 185 |
+
def _timeout_handler(self, signum, frame):
|
| 186 |
+
raise TimeoutError(self.timeout_message)
|
| 187 |
+
|
| 188 |
+
def __enter__(self):
|
| 189 |
+
if platform.system() != 'Windows': # not supported on Windows
|
| 190 |
+
signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM
|
| 191 |
+
signal.alarm(self.seconds) # start countdown for SIGALRM to be raised
|
| 192 |
+
|
| 193 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 194 |
+
if platform.system() != 'Windows':
|
| 195 |
+
signal.alarm(0) # Cancel SIGALRM if it's scheduled
|
| 196 |
+
if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError
|
| 197 |
+
return True
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class WorkingDirectory(contextlib.ContextDecorator):
|
| 201 |
+
# Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager
|
| 202 |
+
def __init__(self, new_dir):
|
| 203 |
+
self.dir = new_dir # new dir
|
| 204 |
+
self.cwd = Path.cwd().resolve() # current dir
|
| 205 |
+
|
| 206 |
+
def __enter__(self):
|
| 207 |
+
os.chdir(self.dir)
|
| 208 |
+
|
| 209 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
| 210 |
+
os.chdir(self.cwd)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def methods(instance):
|
| 214 |
+
# Get class/instance methods
|
| 215 |
+
return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
| 219 |
+
# Print function arguments (optional args dict)
|
| 220 |
+
x = inspect.currentframe().f_back # previous frame
|
| 221 |
+
file, _, func, _, _ = inspect.getframeinfo(x)
|
| 222 |
+
if args is None: # get args automatically
|
| 223 |
+
args, _, _, frm = inspect.getargvalues(x)
|
| 224 |
+
args = {k: v for k, v in frm.items() if k in args}
|
| 225 |
+
try:
|
| 226 |
+
file = Path(file).resolve().relative_to(ROOT).with_suffix('')
|
| 227 |
+
except ValueError:
|
| 228 |
+
file = Path(file).stem
|
| 229 |
+
s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
|
| 230 |
+
LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def init_seeds(seed=0, deterministic=False):
|
| 234 |
+
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
|
| 235 |
+
random.seed(seed)
|
| 236 |
+
np.random.seed(seed)
|
| 237 |
+
torch.manual_seed(seed)
|
| 238 |
+
torch.cuda.manual_seed(seed)
|
| 239 |
+
torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
|
| 240 |
+
# torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
|
| 241 |
+
if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
|
| 242 |
+
torch.use_deterministic_algorithms(True)
|
| 243 |
+
torch.backends.cudnn.deterministic = True
|
| 244 |
+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|
| 245 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def intersect_dicts(da, db, exclude=()):
|
| 249 |
+
# Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
|
| 250 |
+
return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def get_default_args(func):
|
| 254 |
+
# Get func() default arguments
|
| 255 |
+
signature = inspect.signature(func)
|
| 256 |
+
return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def get_latest_run(search_dir='.'):
|
| 260 |
+
# Return path to most recent 'last.pt' in /runs (i.e. to --resume from)
|
| 261 |
+
last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
|
| 262 |
+
return max(last_list, key=os.path.getctime) if last_list else ''
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def file_age(path=__file__):
|
| 266 |
+
# Return days since last file update
|
| 267 |
+
dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta
|
| 268 |
+
return dt.days # + dt.seconds / 86400 # fractional days
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def file_date(path=__file__):
|
| 272 |
+
# Return human-readable file modification date, i.e. '2021-3-26'
|
| 273 |
+
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
| 274 |
+
return f'{t.year}-{t.month}-{t.day}'
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def file_size(path):
|
| 278 |
+
# Return file/dir size (MB)
|
| 279 |
+
mb = 1 << 20 # bytes to MiB (1024 ** 2)
|
| 280 |
+
path = Path(path)
|
| 281 |
+
if path.is_file():
|
| 282 |
+
return path.stat().st_size / mb
|
| 283 |
+
elif path.is_dir():
|
| 284 |
+
return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb
|
| 285 |
+
else:
|
| 286 |
+
return 0.0
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def check_online():
|
| 290 |
+
# Check internet connectivity
|
| 291 |
+
import socket
|
| 292 |
+
|
| 293 |
+
def run_once():
|
| 294 |
+
# Check once
|
| 295 |
+
try:
|
| 296 |
+
socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility
|
| 297 |
+
return True
|
| 298 |
+
except OSError:
|
| 299 |
+
return False
|
| 300 |
+
|
| 301 |
+
return run_once() or run_once() # check twice to increase robustness to intermittent connectivity issues
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def git_describe(path=ROOT): # path must be a directory
|
| 305 |
+
# Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe
|
| 306 |
+
try:
|
| 307 |
+
assert (Path(path) / '.git').is_dir()
|
| 308 |
+
return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1]
|
| 309 |
+
except Exception:
|
| 310 |
+
return ''
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
@TryExcept()
|
| 314 |
+
@WorkingDirectory(ROOT)
|
| 315 |
+
def check_git_status(repo='WongKinYiu/yolov9', branch='main'):
|
| 316 |
+
# YOLO status check, recommend 'git pull' if code is out of date
|
| 317 |
+
url = f'https://github.com/{repo}'
|
| 318 |
+
msg = f', for updates see {url}'
|
| 319 |
+
s = colorstr('github: ') # string
|
| 320 |
+
assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
|
| 321 |
+
assert check_online(), s + 'skipping check (offline)' + msg
|
| 322 |
+
|
| 323 |
+
splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
|
| 324 |
+
matches = [repo in s for s in splits]
|
| 325 |
+
if any(matches):
|
| 326 |
+
remote = splits[matches.index(True) - 1]
|
| 327 |
+
else:
|
| 328 |
+
remote = 'ultralytics'
|
| 329 |
+
check_output(f'git remote add {remote} {url}', shell=True)
|
| 330 |
+
check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
|
| 331 |
+
local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
|
| 332 |
+
n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
|
| 333 |
+
if n > 0:
|
| 334 |
+
pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
|
| 335 |
+
s += f"⚠️ YOLO is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
|
| 336 |
+
else:
|
| 337 |
+
s += f'up to date with {url} ✅'
|
| 338 |
+
LOGGER.info(s)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
@WorkingDirectory(ROOT)
|
| 342 |
+
def check_git_info(path='.'):
|
| 343 |
+
# YOLO git info check, return {remote, branch, commit}
|
| 344 |
+
check_requirements('gitpython')
|
| 345 |
+
import git
|
| 346 |
+
try:
|
| 347 |
+
repo = git.Repo(path)
|
| 348 |
+
remote = repo.remotes.origin.url.replace('.git', '') # i.e. 'https://github.com/WongKinYiu/yolov9'
|
| 349 |
+
commit = repo.head.commit.hexsha # i.e. '3134699c73af83aac2a481435550b968d5792c0d'
|
| 350 |
+
try:
|
| 351 |
+
branch = repo.active_branch.name # i.e. 'main'
|
| 352 |
+
except TypeError: # not on any branch
|
| 353 |
+
branch = None # i.e. 'detached HEAD' state
|
| 354 |
+
return {'remote': remote, 'branch': branch, 'commit': commit}
|
| 355 |
+
except git.exc.InvalidGitRepositoryError: # path is not a git dir
|
| 356 |
+
return {'remote': None, 'branch': None, 'commit': None}
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def check_python(minimum='3.7.0'):
|
| 360 |
+
# Check current python version vs. required python version
|
| 361 |
+
check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
|
| 365 |
+
# Check version vs. required version
|
| 366 |
+
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
| 367 |
+
result = (current == minimum) if pinned else (current >= minimum) # bool
|
| 368 |
+
s = f'WARNING ⚠️ {name}{minimum} is required by YOLO, but {name}{current} is currently installed' # string
|
| 369 |
+
if hard:
|
| 370 |
+
assert result, emojis(s) # assert min requirements met
|
| 371 |
+
if verbose and not result:
|
| 372 |
+
LOGGER.warning(s)
|
| 373 |
+
return result
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@TryExcept()
|
| 377 |
+
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
|
| 378 |
+
# Check installed dependencies meet YOLO requirements (pass *.txt file or list of packages or single package str)
|
| 379 |
+
prefix = colorstr('red', 'bold', 'requirements:')
|
| 380 |
+
check_python() # check python version
|
| 381 |
+
if isinstance(requirements, Path): # requirements.txt file
|
| 382 |
+
file = requirements.resolve()
|
| 383 |
+
assert file.exists(), f"{prefix} {file} not found, check failed."
|
| 384 |
+
with file.open() as f:
|
| 385 |
+
requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
|
| 386 |
+
elif isinstance(requirements, str):
|
| 387 |
+
requirements = [requirements]
|
| 388 |
+
|
| 389 |
+
s = ''
|
| 390 |
+
n = 0
|
| 391 |
+
for r in requirements:
|
| 392 |
+
try:
|
| 393 |
+
pkg.require(r)
|
| 394 |
+
except (pkg.VersionConflict, pkg.DistributionNotFound): # exception if requirements not met
|
| 395 |
+
s += f'"{r}" '
|
| 396 |
+
n += 1
|
| 397 |
+
|
| 398 |
+
if s and install and AUTOINSTALL: # check environment variable
|
| 399 |
+
LOGGER.info(f"{prefix} YOLO requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
|
| 400 |
+
try:
|
| 401 |
+
# assert check_online(), "AutoUpdate skipped (offline)"
|
| 402 |
+
LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
|
| 403 |
+
source = file if 'file' in locals() else requirements
|
| 404 |
+
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
|
| 405 |
+
f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
|
| 406 |
+
LOGGER.info(s)
|
| 407 |
+
except Exception as e:
|
| 408 |
+
LOGGER.warning(f'{prefix} ❌ {e}')
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def check_img_size(imgsz, s=32, floor=0):
|
| 412 |
+
# Verify image size is a multiple of stride s in each dimension
|
| 413 |
+
if isinstance(imgsz, int): # integer i.e. img_size=640
|
| 414 |
+
new_size = max(make_divisible(imgsz, int(s)), floor)
|
| 415 |
+
else: # list i.e. img_size=[640, 480]
|
| 416 |
+
imgsz = list(imgsz) # convert to list if tuple
|
| 417 |
+
new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz]
|
| 418 |
+
if new_size != imgsz:
|
| 419 |
+
LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
|
| 420 |
+
return new_size
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def check_imshow(warn=False):
|
| 424 |
+
# Check if environment supports image displays
|
| 425 |
+
try:
|
| 426 |
+
assert not is_notebook()
|
| 427 |
+
assert not is_docker()
|
| 428 |
+
cv2.imshow('test', np.zeros((1, 1, 3)))
|
| 429 |
+
cv2.waitKey(1)
|
| 430 |
+
cv2.destroyAllWindows()
|
| 431 |
+
cv2.waitKey(1)
|
| 432 |
+
return True
|
| 433 |
+
except Exception as e:
|
| 434 |
+
if warn:
|
| 435 |
+
LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
|
| 436 |
+
return False
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''):
|
| 440 |
+
# Check file(s) for acceptable suffix
|
| 441 |
+
if file and suffix:
|
| 442 |
+
if isinstance(suffix, str):
|
| 443 |
+
suffix = [suffix]
|
| 444 |
+
for f in file if isinstance(file, (list, tuple)) else [file]:
|
| 445 |
+
s = Path(f).suffix.lower() # file suffix
|
| 446 |
+
if len(s):
|
| 447 |
+
assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}"
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def check_yaml(file, suffix=('.yaml', '.yml')):
|
| 451 |
+
# Search/download YAML file (if necessary) and return path, checking suffix
|
| 452 |
+
return check_file(file, suffix)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def check_file(file, suffix=''):
|
| 456 |
+
# Search/download file (if necessary) and return path
|
| 457 |
+
check_suffix(file, suffix) # optional
|
| 458 |
+
file = str(file) # convert to str()
|
| 459 |
+
if os.path.isfile(file) or not file: # exists
|
| 460 |
+
return file
|
| 461 |
+
elif file.startswith(('http:/', 'https:/')): # download
|
| 462 |
+
url = file # warning: Pathlib turns :// -> :/
|
| 463 |
+
file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth
|
| 464 |
+
if os.path.isfile(file):
|
| 465 |
+
LOGGER.info(f'Found {url} locally at {file}') # file already exists
|
| 466 |
+
else:
|
| 467 |
+
LOGGER.info(f'Downloading {url} to {file}...')
|
| 468 |
+
torch.hub.download_url_to_file(url, file)
|
| 469 |
+
assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check
|
| 470 |
+
return file
|
| 471 |
+
elif file.startswith('clearml://'): # ClearML Dataset ID
|
| 472 |
+
assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
|
| 473 |
+
return file
|
| 474 |
+
else: # search
|
| 475 |
+
files = []
|
| 476 |
+
for d in 'data', 'models', 'utils': # search directories
|
| 477 |
+
files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file
|
| 478 |
+
assert len(files), f'File not found: {file}' # assert file was found
|
| 479 |
+
assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique
|
| 480 |
+
return files[0] # return file
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def check_font(font=FONT, progress=False):
|
| 484 |
+
# Download font to CONFIG_DIR if necessary
|
| 485 |
+
font = Path(font)
|
| 486 |
+
file = CONFIG_DIR / font.name
|
| 487 |
+
if not font.exists() and not file.exists():
|
| 488 |
+
url = f'https://ultralytics.com/assets/{font.name}'
|
| 489 |
+
LOGGER.info(f'Downloading {url} to {file}...')
|
| 490 |
+
torch.hub.download_url_to_file(url, str(file), progress=progress)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def check_dataset(data, autodownload=True):
|
| 494 |
+
# Download, check and/or unzip dataset if not found locally
|
| 495 |
+
|
| 496 |
+
# Download (optional)
|
| 497 |
+
extract_dir = ''
|
| 498 |
+
if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
|
| 499 |
+
download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
|
| 500 |
+
data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))
|
| 501 |
+
extract_dir, autodownload = data.parent, False
|
| 502 |
+
|
| 503 |
+
# Read yaml (optional)
|
| 504 |
+
if isinstance(data, (str, Path)):
|
| 505 |
+
data = yaml_load(data) # dictionary
|
| 506 |
+
|
| 507 |
+
# Checks
|
| 508 |
+
for k in 'train', 'val', 'names':
|
| 509 |
+
assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
|
| 510 |
+
if isinstance(data['names'], (list, tuple)): # old array format
|
| 511 |
+
data['names'] = dict(enumerate(data['names'])) # convert to dict
|
| 512 |
+
assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
|
| 513 |
+
data['nc'] = len(data['names'])
|
| 514 |
+
|
| 515 |
+
# Resolve paths
|
| 516 |
+
path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
|
| 517 |
+
if not path.is_absolute():
|
| 518 |
+
path = (ROOT / path).resolve()
|
| 519 |
+
data['path'] = path # download scripts
|
| 520 |
+
for k in 'train', 'val', 'test':
|
| 521 |
+
if data.get(k): # prepend path
|
| 522 |
+
if isinstance(data[k], str):
|
| 523 |
+
x = (path / data[k]).resolve()
|
| 524 |
+
if not x.exists() and data[k].startswith('../'):
|
| 525 |
+
x = (path / data[k][3:]).resolve()
|
| 526 |
+
data[k] = str(x)
|
| 527 |
+
else:
|
| 528 |
+
data[k] = [str((path / x).resolve()) for x in data[k]]
|
| 529 |
+
|
| 530 |
+
# Parse yaml
|
| 531 |
+
train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
|
| 532 |
+
if val:
|
| 533 |
+
val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path
|
| 534 |
+
if not all(x.exists() for x in val):
|
| 535 |
+
LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
|
| 536 |
+
if not s or not autodownload:
|
| 537 |
+
raise Exception('Dataset not found ❌')
|
| 538 |
+
t = time.time()
|
| 539 |
+
if s.startswith('http') and s.endswith('.zip'): # URL
|
| 540 |
+
f = Path(s).name # filename
|
| 541 |
+
LOGGER.info(f'Downloading {s} to {f}...')
|
| 542 |
+
torch.hub.download_url_to_file(s, f)
|
| 543 |
+
Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root
|
| 544 |
+
unzip_file(f, path=DATASETS_DIR) # unzip
|
| 545 |
+
Path(f).unlink() # remove zip
|
| 546 |
+
r = None # success
|
| 547 |
+
elif s.startswith('bash '): # bash script
|
| 548 |
+
LOGGER.info(f'Running {s} ...')
|
| 549 |
+
r = os.system(s)
|
| 550 |
+
else: # python script
|
| 551 |
+
r = exec(s, {'yaml': data}) # return None
|
| 552 |
+
dt = f'({round(time.time() - t, 1)}s)'
|
| 553 |
+
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
| 554 |
+
LOGGER.info(f"Dataset download {s}")
|
| 555 |
+
check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts
|
| 556 |
+
return data # dictionary
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def check_amp(model):
|
| 560 |
+
# Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation
|
| 561 |
+
from models.common import AutoShape, DetectMultiBackend
|
| 562 |
+
|
| 563 |
+
def amp_allclose(model, im):
|
| 564 |
+
# All close FP32 vs AMP results
|
| 565 |
+
m = AutoShape(model, verbose=False) # model
|
| 566 |
+
a = m(im).xywhn[0] # FP32 inference
|
| 567 |
+
m.amp = True
|
| 568 |
+
b = m(im).xywhn[0] # AMP inference
|
| 569 |
+
return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance
|
| 570 |
+
|
| 571 |
+
prefix = colorstr('AMP: ')
|
| 572 |
+
device = next(model.parameters()).device # get model device
|
| 573 |
+
if device.type in ('cpu', 'mps'):
|
| 574 |
+
return False # AMP only used on CUDA devices
|
| 575 |
+
f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
|
| 576 |
+
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
|
| 577 |
+
try:
|
| 578 |
+
#assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolo.pt', device), im)
|
| 579 |
+
LOGGER.info(f'{prefix}checks passed ✅')
|
| 580 |
+
return True
|
| 581 |
+
except Exception:
|
| 582 |
+
help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
|
| 583 |
+
LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
|
| 584 |
+
return False
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def yaml_load(file='data.yaml'):
|
| 588 |
+
# Single-line safe yaml loading
|
| 589 |
+
with open(file, errors='ignore') as f:
|
| 590 |
+
return yaml.safe_load(f)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def yaml_save(file='data.yaml', data={}):
|
| 594 |
+
# Single-line safe yaml saving
|
| 595 |
+
with open(file, 'w') as f:
|
| 596 |
+
yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
|
| 600 |
+
# Unzip a *.zip file to path/, excluding files containing strings in exclude list
|
| 601 |
+
if path is None:
|
| 602 |
+
path = Path(file).parent # default path
|
| 603 |
+
with ZipFile(file) as zipObj:
|
| 604 |
+
for f in zipObj.namelist(): # list all archived filenames in the zip
|
| 605 |
+
if all(x not in f for x in exclude):
|
| 606 |
+
zipObj.extract(f, path=path)
|
| 607 |
+
|
| 608 |
+
|
| 609 |
+
def url2file(url):
|
| 610 |
+
# Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
|
| 611 |
+
url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
|
| 612 |
+
return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
|
| 616 |
+
# Multithreaded file download and unzip function, used in data.yaml for autodownload
|
| 617 |
+
def download_one(url, dir):
|
| 618 |
+
# Download 1 file
|
| 619 |
+
success = True
|
| 620 |
+
if os.path.isfile(url):
|
| 621 |
+
f = Path(url) # filename
|
| 622 |
+
else: # does not exist
|
| 623 |
+
f = dir / Path(url).name
|
| 624 |
+
LOGGER.info(f'Downloading {url} to {f}...')
|
| 625 |
+
for i in range(retry + 1):
|
| 626 |
+
if curl:
|
| 627 |
+
s = 'sS' if threads > 1 else '' # silent
|
| 628 |
+
r = os.system(
|
| 629 |
+
f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
|
| 630 |
+
success = r == 0
|
| 631 |
+
else:
|
| 632 |
+
torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
|
| 633 |
+
success = f.is_file()
|
| 634 |
+
if success:
|
| 635 |
+
break
|
| 636 |
+
elif i < retry:
|
| 637 |
+
LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
|
| 638 |
+
else:
|
| 639 |
+
LOGGER.warning(f'❌ Failed to download {url}...')
|
| 640 |
+
|
| 641 |
+
if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
|
| 642 |
+
LOGGER.info(f'Unzipping {f}...')
|
| 643 |
+
if is_zipfile(f):
|
| 644 |
+
unzip_file(f, dir) # unzip
|
| 645 |
+
elif is_tarfile(f):
|
| 646 |
+
os.system(f'tar xf {f} --directory {f.parent}') # unzip
|
| 647 |
+
elif f.suffix == '.gz':
|
| 648 |
+
os.system(f'tar xfz {f} --directory {f.parent}') # unzip
|
| 649 |
+
if delete:
|
| 650 |
+
f.unlink() # remove zip
|
| 651 |
+
|
| 652 |
+
dir = Path(dir)
|
| 653 |
+
dir.mkdir(parents=True, exist_ok=True) # make directory
|
| 654 |
+
if threads > 1:
|
| 655 |
+
pool = ThreadPool(threads)
|
| 656 |
+
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
|
| 657 |
+
pool.close()
|
| 658 |
+
pool.join()
|
| 659 |
+
else:
|
| 660 |
+
for u in [url] if isinstance(url, (str, Path)) else url:
|
| 661 |
+
download_one(u, dir)
|
| 662 |
+
|
| 663 |
+
|
| 664 |
+
def make_divisible(x, divisor):
|
| 665 |
+
# Returns nearest x divisible by divisor
|
| 666 |
+
if isinstance(divisor, torch.Tensor):
|
| 667 |
+
divisor = int(divisor.max()) # to int
|
| 668 |
+
return math.ceil(x / divisor) * divisor
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def clean_str(s):
|
| 672 |
+
# Cleans a string by replacing special characters with underscore _
|
| 673 |
+
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def one_cycle(y1=0.0, y2=1.0, steps=100):
|
| 677 |
+
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
|
| 678 |
+
return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def one_flat_cycle(y1=0.0, y2=1.0, steps=100):
|
| 682 |
+
# lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf
|
| 683 |
+
#return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
|
| 684 |
+
return lambda x: ((1 - math.cos((x - (steps // 2)) * math.pi / (steps // 2))) / 2) * (y2 - y1) + y1 if (x > (steps // 2)) else y1
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def colorstr(*input):
|
| 688 |
+
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
|
| 689 |
+
*args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string
|
| 690 |
+
colors = {
|
| 691 |
+
'black': '\033[30m', # basic colors
|
| 692 |
+
'red': '\033[31m',
|
| 693 |
+
'green': '\033[32m',
|
| 694 |
+
'yellow': '\033[33m',
|
| 695 |
+
'blue': '\033[34m',
|
| 696 |
+
'magenta': '\033[35m',
|
| 697 |
+
'cyan': '\033[36m',
|
| 698 |
+
'white': '\033[37m',
|
| 699 |
+
'bright_black': '\033[90m', # bright colors
|
| 700 |
+
'bright_red': '\033[91m',
|
| 701 |
+
'bright_green': '\033[92m',
|
| 702 |
+
'bright_yellow': '\033[93m',
|
| 703 |
+
'bright_blue': '\033[94m',
|
| 704 |
+
'bright_magenta': '\033[95m',
|
| 705 |
+
'bright_cyan': '\033[96m',
|
| 706 |
+
'bright_white': '\033[97m',
|
| 707 |
+
'end': '\033[0m', # misc
|
| 708 |
+
'bold': '\033[1m',
|
| 709 |
+
'underline': '\033[4m'}
|
| 710 |
+
return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def labels_to_class_weights(labels, nc=80):
|
| 714 |
+
# Get class weights (inverse frequency) from training labels
|
| 715 |
+
if labels[0] is None: # no labels loaded
|
| 716 |
+
return torch.Tensor()
|
| 717 |
+
|
| 718 |
+
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
|
| 719 |
+
classes = labels[:, 0].astype(int) # labels = [class xywh]
|
| 720 |
+
weights = np.bincount(classes, minlength=nc) # occurrences per class
|
| 721 |
+
|
| 722 |
+
# Prepend gridpoint count (for uCE training)
|
| 723 |
+
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
|
| 724 |
+
# weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
|
| 725 |
+
|
| 726 |
+
weights[weights == 0] = 1 # replace empty bins with 1
|
| 727 |
+
weights = 1 / weights # number of targets per class
|
| 728 |
+
weights /= weights.sum() # normalize
|
| 729 |
+
return torch.from_numpy(weights).float()
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
|
| 733 |
+
# Produces image weights based on class_weights and image contents
|
| 734 |
+
# Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample
|
| 735 |
+
class_counts = np.array([np.bincount(x[:, 0].astype(int), minlength=nc) for x in labels])
|
| 736 |
+
return (class_weights.reshape(1, nc) * class_counts).sum(1)
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper)
|
| 740 |
+
# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
|
| 741 |
+
# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
|
| 742 |
+
# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
|
| 743 |
+
# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco
|
| 744 |
+
# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet
|
| 745 |
+
return [
|
| 746 |
+
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
|
| 747 |
+
35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
|
| 748 |
+
64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
def xyxy2xywh(x):
|
| 752 |
+
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
|
| 753 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 754 |
+
y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center
|
| 755 |
+
y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center
|
| 756 |
+
y[..., 2] = x[..., 2] - x[..., 0] # width
|
| 757 |
+
y[..., 3] = x[..., 3] - x[..., 1] # height
|
| 758 |
+
return y
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def xywh2xyxy(x):
|
| 762 |
+
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
| 763 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 764 |
+
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
|
| 765 |
+
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
|
| 766 |
+
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
|
| 767 |
+
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
|
| 768 |
+
return y
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
|
| 772 |
+
# Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
|
| 773 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 774 |
+
y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x
|
| 775 |
+
y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y
|
| 776 |
+
y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x
|
| 777 |
+
y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y
|
| 778 |
+
return y
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
|
| 782 |
+
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
|
| 783 |
+
if clip:
|
| 784 |
+
clip_boxes(x, (h - eps, w - eps)) # warning: inplace clip
|
| 785 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 786 |
+
y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center
|
| 787 |
+
y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center
|
| 788 |
+
y[..., 2] = (x[..., 2] - x[..., 0]) / w # width
|
| 789 |
+
y[..., 3] = (x[..., 3] - x[..., 1]) / h # height
|
| 790 |
+
return y
|
| 791 |
+
|
| 792 |
+
|
| 793 |
+
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
|
| 794 |
+
# Convert normalized segments into pixel segments, shape (n,2)
|
| 795 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 796 |
+
y[..., 0] = w * x[..., 0] + padw # top left x
|
| 797 |
+
y[..., 1] = h * x[..., 1] + padh # top left y
|
| 798 |
+
return y
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
def segment2box(segment, width=640, height=640):
|
| 802 |
+
# Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy)
|
| 803 |
+
x, y = segment.T # segment xy
|
| 804 |
+
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
|
| 805 |
+
x, y, = x[inside], y[inside]
|
| 806 |
+
return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
def segments2boxes(segments):
|
| 810 |
+
# Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
|
| 811 |
+
boxes = []
|
| 812 |
+
for s in segments:
|
| 813 |
+
x, y = s.T # segment xy
|
| 814 |
+
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
|
| 815 |
+
return xyxy2xywh(np.array(boxes)) # cls, xywh
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def resample_segments(segments, n=1000):
|
| 819 |
+
# Up-sample an (n,2) segment
|
| 820 |
+
for i, s in enumerate(segments):
|
| 821 |
+
s = np.concatenate((s, s[0:1, :]), axis=0)
|
| 822 |
+
x = np.linspace(0, len(s) - 1, n)
|
| 823 |
+
xp = np.arange(len(s))
|
| 824 |
+
segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy
|
| 825 |
+
return segments
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
|
| 829 |
+
# Rescale boxes (xyxy) from img1_shape to img0_shape
|
| 830 |
+
if ratio_pad is None: # calculate from img0_shape
|
| 831 |
+
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
| 832 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
| 833 |
+
else:
|
| 834 |
+
gain = ratio_pad[0][0]
|
| 835 |
+
pad = ratio_pad[1]
|
| 836 |
+
|
| 837 |
+
boxes[:, [0, 2]] -= pad[0] # x padding
|
| 838 |
+
boxes[:, [1, 3]] -= pad[1] # y padding
|
| 839 |
+
boxes[:, :4] /= gain
|
| 840 |
+
clip_boxes(boxes, img0_shape)
|
| 841 |
+
return boxes
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
|
| 845 |
+
# Rescale coords (xyxy) from img1_shape to img0_shape
|
| 846 |
+
if ratio_pad is None: # calculate from img0_shape
|
| 847 |
+
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
|
| 848 |
+
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
|
| 849 |
+
else:
|
| 850 |
+
gain = ratio_pad[0][0]
|
| 851 |
+
pad = ratio_pad[1]
|
| 852 |
+
|
| 853 |
+
segments[:, 0] -= pad[0] # x padding
|
| 854 |
+
segments[:, 1] -= pad[1] # y padding
|
| 855 |
+
segments /= gain
|
| 856 |
+
clip_segments(segments, img0_shape)
|
| 857 |
+
if normalize:
|
| 858 |
+
segments[:, 0] /= img0_shape[1] # width
|
| 859 |
+
segments[:, 1] /= img0_shape[0] # height
|
| 860 |
+
return segments
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def clip_boxes(boxes, shape):
|
| 864 |
+
# Clip boxes (xyxy) to image shape (height, width)
|
| 865 |
+
if isinstance(boxes, torch.Tensor): # faster individually
|
| 866 |
+
boxes[:, 0].clamp_(0, shape[1]) # x1
|
| 867 |
+
boxes[:, 1].clamp_(0, shape[0]) # y1
|
| 868 |
+
boxes[:, 2].clamp_(0, shape[1]) # x2
|
| 869 |
+
boxes[:, 3].clamp_(0, shape[0]) # y2
|
| 870 |
+
else: # np.array (faster grouped)
|
| 871 |
+
boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2
|
| 872 |
+
boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
def clip_segments(segments, shape):
|
| 876 |
+
# Clip segments (xy1,xy2,...) to image shape (height, width)
|
| 877 |
+
if isinstance(segments, torch.Tensor): # faster individually
|
| 878 |
+
segments[:, 0].clamp_(0, shape[1]) # x
|
| 879 |
+
segments[:, 1].clamp_(0, shape[0]) # y
|
| 880 |
+
else: # np.array (faster grouped)
|
| 881 |
+
segments[:, 0] = segments[:, 0].clip(0, shape[1]) # x
|
| 882 |
+
segments[:, 1] = segments[:, 1].clip(0, shape[0]) # y
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
def non_max_suppression(
|
| 886 |
+
prediction,
|
| 887 |
+
conf_thres=0.25,
|
| 888 |
+
iou_thres=0.45,
|
| 889 |
+
classes=None,
|
| 890 |
+
agnostic=False,
|
| 891 |
+
multi_label=False,
|
| 892 |
+
labels=(),
|
| 893 |
+
max_det=300,
|
| 894 |
+
nm=0, # number of masks
|
| 895 |
+
):
|
| 896 |
+
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
|
| 897 |
+
|
| 898 |
+
Returns:
|
| 899 |
+
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
|
| 900 |
+
"""
|
| 901 |
+
|
| 902 |
+
if isinstance(prediction, (list, tuple)): # YOLO model in validation model, output = (inference_out, loss_out)
|
| 903 |
+
prediction = prediction[0] # select only inference output
|
| 904 |
+
|
| 905 |
+
device = prediction.device
|
| 906 |
+
mps = 'mps' in device.type # Apple MPS
|
| 907 |
+
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
|
| 908 |
+
prediction = prediction.cpu()
|
| 909 |
+
bs = prediction.shape[0] # batch size
|
| 910 |
+
nc = prediction.shape[1] - nm - 4 # number of classes
|
| 911 |
+
mi = 4 + nc # mask start index
|
| 912 |
+
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
|
| 913 |
+
|
| 914 |
+
# Checks
|
| 915 |
+
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
|
| 916 |
+
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
|
| 917 |
+
|
| 918 |
+
# Settings
|
| 919 |
+
# min_wh = 2 # (pixels) minimum box width and height
|
| 920 |
+
max_wh = 7680 # (pixels) maximum box width and height
|
| 921 |
+
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
|
| 922 |
+
time_limit = 2.5 + 0.05 * bs # seconds to quit after
|
| 923 |
+
redundant = True # require redundant detections
|
| 924 |
+
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
|
| 925 |
+
merge = False # use merge-NMS
|
| 926 |
+
|
| 927 |
+
t = time.time()
|
| 928 |
+
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
|
| 929 |
+
for xi, x in enumerate(prediction): # image index, image inference
|
| 930 |
+
# Apply constraints
|
| 931 |
+
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
|
| 932 |
+
x = x.T[xc[xi]] # confidence
|
| 933 |
+
|
| 934 |
+
# Cat apriori labels if autolabelling
|
| 935 |
+
if labels and len(labels[xi]):
|
| 936 |
+
lb = labels[xi]
|
| 937 |
+
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
|
| 938 |
+
v[:, :4] = lb[:, 1:5] # box
|
| 939 |
+
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
|
| 940 |
+
x = torch.cat((x, v), 0)
|
| 941 |
+
|
| 942 |
+
# If none remain process next image
|
| 943 |
+
if not x.shape[0]:
|
| 944 |
+
continue
|
| 945 |
+
|
| 946 |
+
# Detections matrix nx6 (xyxy, conf, cls)
|
| 947 |
+
box, cls, mask = x.split((4, nc, nm), 1)
|
| 948 |
+
box = xywh2xyxy(box) # center_x, center_y, width, height) to (x1, y1, x2, y2)
|
| 949 |
+
if multi_label:
|
| 950 |
+
i, j = (cls > conf_thres).nonzero(as_tuple=False).T
|
| 951 |
+
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
|
| 952 |
+
else: # best class only
|
| 953 |
+
conf, j = cls.max(1, keepdim=True)
|
| 954 |
+
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
|
| 955 |
+
|
| 956 |
+
# Filter by class
|
| 957 |
+
if classes is not None:
|
| 958 |
+
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
|
| 959 |
+
|
| 960 |
+
# Apply finite constraint
|
| 961 |
+
# if not torch.isfinite(x).all():
|
| 962 |
+
# x = x[torch.isfinite(x).all(1)]
|
| 963 |
+
|
| 964 |
+
# Check shape
|
| 965 |
+
n = x.shape[0] # number of boxes
|
| 966 |
+
if not n: # no boxes
|
| 967 |
+
continue
|
| 968 |
+
elif n > max_nms: # excess boxes
|
| 969 |
+
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
|
| 970 |
+
else:
|
| 971 |
+
x = x[x[:, 4].argsort(descending=True)] # sort by confidence
|
| 972 |
+
|
| 973 |
+
# Batched NMS
|
| 974 |
+
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
|
| 975 |
+
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
|
| 976 |
+
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
|
| 977 |
+
if i.shape[0] > max_det: # limit detections
|
| 978 |
+
i = i[:max_det]
|
| 979 |
+
if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
|
| 980 |
+
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
|
| 981 |
+
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
|
| 982 |
+
weights = iou * scores[None] # box weights
|
| 983 |
+
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
|
| 984 |
+
if redundant:
|
| 985 |
+
i = i[iou.sum(1) > 1] # require redundancy
|
| 986 |
+
|
| 987 |
+
output[xi] = x[i]
|
| 988 |
+
if mps:
|
| 989 |
+
output[xi] = output[xi].to(device)
|
| 990 |
+
if (time.time() - t) > time_limit:
|
| 991 |
+
LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
|
| 992 |
+
break # time limit exceeded
|
| 993 |
+
|
| 994 |
+
return output
|
| 995 |
+
|
| 996 |
+
|
| 997 |
+
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
|
| 998 |
+
# Strip optimizer from 'f' to finalize training, optionally save as 's'
|
| 999 |
+
x = torch.load(f, map_location=torch.device('cpu'))
|
| 1000 |
+
if x.get('ema'):
|
| 1001 |
+
x['model'] = x['ema'] # replace model with ema
|
| 1002 |
+
for k in 'optimizer', 'best_fitness', 'ema', 'updates': # keys
|
| 1003 |
+
x[k] = None
|
| 1004 |
+
x['epoch'] = -1
|
| 1005 |
+
x['model'].half() # to FP16
|
| 1006 |
+
for p in x['model'].parameters():
|
| 1007 |
+
p.requires_grad = False
|
| 1008 |
+
torch.save(x, s or f)
|
| 1009 |
+
mb = os.path.getsize(s or f) / 1E6 # filesize
|
| 1010 |
+
LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
|
| 1014 |
+
evolve_csv = save_dir / 'evolve.csv'
|
| 1015 |
+
evolve_yaml = save_dir / 'hyp_evolve.yaml'
|
| 1016 |
+
keys = tuple(keys) + tuple(hyp.keys()) # [results + hyps]
|
| 1017 |
+
keys = tuple(x.strip() for x in keys)
|
| 1018 |
+
vals = results + tuple(hyp.values())
|
| 1019 |
+
n = len(keys)
|
| 1020 |
+
|
| 1021 |
+
# Download (optional)
|
| 1022 |
+
if bucket:
|
| 1023 |
+
url = f'gs://{bucket}/evolve.csv'
|
| 1024 |
+
if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
|
| 1025 |
+
os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local
|
| 1026 |
+
|
| 1027 |
+
# Log to evolve.csv
|
| 1028 |
+
s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header
|
| 1029 |
+
with open(evolve_csv, 'a') as f:
|
| 1030 |
+
f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')
|
| 1031 |
+
|
| 1032 |
+
# Save yaml
|
| 1033 |
+
with open(evolve_yaml, 'w') as f:
|
| 1034 |
+
data = pd.read_csv(evolve_csv)
|
| 1035 |
+
data = data.rename(columns=lambda x: x.strip()) # strip keys
|
| 1036 |
+
i = np.argmax(fitness(data.values[:, :4])) #
|
| 1037 |
+
generations = len(data)
|
| 1038 |
+
f.write('# YOLO Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
|
| 1039 |
+
f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
|
| 1040 |
+
'\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
|
| 1041 |
+
yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)
|
| 1042 |
+
|
| 1043 |
+
# Print to screen
|
| 1044 |
+
LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
|
| 1045 |
+
', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
|
| 1046 |
+
for x in vals) + '\n\n')
|
| 1047 |
+
|
| 1048 |
+
if bucket:
|
| 1049 |
+
os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def apply_classifier(x, model, img, im0):
|
| 1053 |
+
# Apply a second stage classifier to YOLO outputs
|
| 1054 |
+
# Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
|
| 1055 |
+
im0 = [im0] if isinstance(im0, np.ndarray) else im0
|
| 1056 |
+
for i, d in enumerate(x): # per image
|
| 1057 |
+
if d is not None and len(d):
|
| 1058 |
+
d = d.clone()
|
| 1059 |
+
|
| 1060 |
+
# Reshape and pad cutouts
|
| 1061 |
+
b = xyxy2xywh(d[:, :4]) # boxes
|
| 1062 |
+
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square
|
| 1063 |
+
b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad
|
| 1064 |
+
d[:, :4] = xywh2xyxy(b).long()
|
| 1065 |
+
|
| 1066 |
+
# Rescale boxes from img_size to im0 size
|
| 1067 |
+
scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)
|
| 1068 |
+
|
| 1069 |
+
# Classes
|
| 1070 |
+
pred_cls1 = d[:, 5].long()
|
| 1071 |
+
ims = []
|
| 1072 |
+
for a in d:
|
| 1073 |
+
cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]
|
| 1074 |
+
im = cv2.resize(cutout, (224, 224)) # BGR
|
| 1075 |
+
|
| 1076 |
+
im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
|
| 1077 |
+
im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32
|
| 1078 |
+
im /= 255 # 0 - 255 to 0.0 - 1.0
|
| 1079 |
+
ims.append(im)
|
| 1080 |
+
|
| 1081 |
+
pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction
|
| 1082 |
+
x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections
|
| 1083 |
+
|
| 1084 |
+
return x
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
def increment_path(path, exist_ok=False, sep='', mkdir=False):
|
| 1088 |
+
# Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
| 1089 |
+
path = Path(path) # os-agnostic
|
| 1090 |
+
if path.exists() and not exist_ok:
|
| 1091 |
+
path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')
|
| 1092 |
+
|
| 1093 |
+
# Method 1
|
| 1094 |
+
for n in range(2, 9999):
|
| 1095 |
+
p = f'{path}{sep}{n}{suffix}' # increment path
|
| 1096 |
+
if not os.path.exists(p): #
|
| 1097 |
+
break
|
| 1098 |
+
path = Path(p)
|
| 1099 |
+
|
| 1100 |
+
# Method 2 (deprecated)
|
| 1101 |
+
# dirs = glob.glob(f"{path}{sep}*") # similar paths
|
| 1102 |
+
# matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs]
|
| 1103 |
+
# i = [int(m.groups()[0]) for m in matches if m] # indices
|
| 1104 |
+
# n = max(i) + 1 if i else 2 # increment number
|
| 1105 |
+
# path = Path(f"{path}{sep}{n}{suffix}") # increment path
|
| 1106 |
+
|
| 1107 |
+
if mkdir:
|
| 1108 |
+
path.mkdir(parents=True, exist_ok=True) # make directory
|
| 1109 |
+
|
| 1110 |
+
return path
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
# OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
|
| 1114 |
+
imshow_ = cv2.imshow # copy to avoid recursion errors
|
| 1115 |
+
|
| 1116 |
+
|
| 1117 |
+
def imread(path, flags=cv2.IMREAD_COLOR):
|
| 1118 |
+
return cv2.imdecode(np.fromfile(path, np.uint8), flags)
|
| 1119 |
+
|
| 1120 |
+
|
| 1121 |
+
def imwrite(path, im):
|
| 1122 |
+
try:
|
| 1123 |
+
cv2.imencode(Path(path).suffix, im)[1].tofile(path)
|
| 1124 |
+
return True
|
| 1125 |
+
except Exception:
|
| 1126 |
+
return False
|
| 1127 |
+
|
| 1128 |
+
|
| 1129 |
+
def imshow(path, im):
|
| 1130 |
+
imshow_(path.encode('unicode_escape').decode(), im)
|
| 1131 |
+
|
| 1132 |
+
|
| 1133 |
+
cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine
|
| 1134 |
+
|
| 1135 |
+
# Variables ------------------------------------------------------------------------------------------------------------
|
plots.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
from copy import copy
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from urllib.error import URLError
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import matplotlib
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import seaborn as sn
|
| 14 |
+
import torch
|
| 15 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 16 |
+
|
| 17 |
+
from utils import TryExcept, threaded
|
| 18 |
+
from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
|
| 19 |
+
is_ascii, xywh2xyxy, xyxy2xywh)
|
| 20 |
+
from utils.metrics import fitness
|
| 21 |
+
from utils.segment.general import scale_image
|
| 22 |
+
|
| 23 |
+
# Settings
|
| 24 |
+
RANK = int(os.getenv('RANK', -1))
|
| 25 |
+
matplotlib.rc('font', **{'size': 11})
|
| 26 |
+
matplotlib.use('Agg') # for writing to files only
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Colors:
|
| 30 |
+
# Ultralytics color palette https://ultralytics.com/
|
| 31 |
+
def __init__(self):
|
| 32 |
+
# hex = matplotlib.colors.TABLEAU_COLORS.values()
|
| 33 |
+
hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
|
| 34 |
+
'2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
|
| 35 |
+
self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
|
| 36 |
+
self.n = len(self.palette)
|
| 37 |
+
|
| 38 |
+
def __call__(self, i, bgr=False):
|
| 39 |
+
c = self.palette[int(i) % self.n]
|
| 40 |
+
return (c[2], c[1], c[0]) if bgr else c
|
| 41 |
+
|
| 42 |
+
@staticmethod
|
| 43 |
+
def hex2rgb(h): # rgb order (PIL)
|
| 44 |
+
return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
colors = Colors() # create instance for 'from utils.plots import colors'
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def check_pil_font(font=FONT, size=10):
|
| 51 |
+
# Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary
|
| 52 |
+
font = Path(font)
|
| 53 |
+
font = font if font.exists() else (CONFIG_DIR / font.name)
|
| 54 |
+
try:
|
| 55 |
+
return ImageFont.truetype(str(font) if font.exists() else font.name, size)
|
| 56 |
+
except Exception: # download if missing
|
| 57 |
+
try:
|
| 58 |
+
check_font(font)
|
| 59 |
+
return ImageFont.truetype(str(font), size)
|
| 60 |
+
except TypeError:
|
| 61 |
+
check_requirements('Pillow>=8.4.0') # known issue https://github.com/ultralytics/yolov5/issues/5374
|
| 62 |
+
except URLError: # not online
|
| 63 |
+
return ImageFont.load_default()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class Annotator:
|
| 67 |
+
# YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations
|
| 68 |
+
def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
|
| 69 |
+
assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
|
| 70 |
+
non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
|
| 71 |
+
self.pil = pil or non_ascii
|
| 72 |
+
if self.pil: # use PIL
|
| 73 |
+
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
| 74 |
+
self.draw = ImageDraw.Draw(self.im)
|
| 75 |
+
self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
|
| 76 |
+
size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
|
| 77 |
+
else: # use cv2
|
| 78 |
+
self.im = im
|
| 79 |
+
self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2) # line width
|
| 80 |
+
|
| 81 |
+
def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
|
| 82 |
+
# Add one xyxy box to image with label
|
| 83 |
+
if self.pil or not is_ascii(label):
|
| 84 |
+
self.draw.rectangle(box, width=self.lw, outline=color) # box
|
| 85 |
+
if label:
|
| 86 |
+
w, h = self.font.getsize(label) # text width, height
|
| 87 |
+
outside = box[1] - h >= 0 # label fits outside box
|
| 88 |
+
self.draw.rectangle(
|
| 89 |
+
(box[0], box[1] - h if outside else box[1], box[0] + w + 1,
|
| 90 |
+
box[1] + 1 if outside else box[1] + h + 1),
|
| 91 |
+
fill=color,
|
| 92 |
+
)
|
| 93 |
+
# self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
|
| 94 |
+
self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
|
| 95 |
+
else: # cv2
|
| 96 |
+
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
|
| 97 |
+
cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
|
| 98 |
+
if label:
|
| 99 |
+
tf = max(self.lw - 1, 1) # font thickness
|
| 100 |
+
w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0] # text width, height
|
| 101 |
+
outside = p1[1] - h >= 3
|
| 102 |
+
p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
|
| 103 |
+
cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
|
| 104 |
+
cv2.putText(self.im,
|
| 105 |
+
label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
|
| 106 |
+
0,
|
| 107 |
+
self.lw / 3,
|
| 108 |
+
txt_color,
|
| 109 |
+
thickness=tf,
|
| 110 |
+
lineType=cv2.LINE_AA)
|
| 111 |
+
|
| 112 |
+
def masks(self, masks, colors, im_gpu=None, alpha=0.5):
|
| 113 |
+
"""Plot masks at once.
|
| 114 |
+
Args:
|
| 115 |
+
masks (tensor): predicted masks on cuda, shape: [n, h, w]
|
| 116 |
+
colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
|
| 117 |
+
im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
|
| 118 |
+
alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
|
| 119 |
+
"""
|
| 120 |
+
if self.pil:
|
| 121 |
+
# convert to numpy first
|
| 122 |
+
self.im = np.asarray(self.im).copy()
|
| 123 |
+
if im_gpu is None:
|
| 124 |
+
# Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
|
| 125 |
+
if len(masks) == 0:
|
| 126 |
+
return
|
| 127 |
+
if isinstance(masks, torch.Tensor):
|
| 128 |
+
masks = torch.as_tensor(masks, dtype=torch.uint8)
|
| 129 |
+
masks = masks.permute(1, 2, 0).contiguous()
|
| 130 |
+
masks = masks.cpu().numpy()
|
| 131 |
+
# masks = np.ascontiguousarray(masks.transpose(1, 2, 0))
|
| 132 |
+
masks = scale_image(masks.shape[:2], masks, self.im.shape)
|
| 133 |
+
masks = np.asarray(masks, dtype=np.float32)
|
| 134 |
+
colors = np.asarray(colors, dtype=np.float32) # shape(n,3)
|
| 135 |
+
s = masks.sum(2, keepdims=True).clip(0, 1) # add all masks together
|
| 136 |
+
masks = (masks @ colors).clip(0, 255) # (h,w,n) @ (n,3) = (h,w,3)
|
| 137 |
+
self.im[:] = masks * alpha + self.im * (1 - s * alpha)
|
| 138 |
+
else:
|
| 139 |
+
if len(masks) == 0:
|
| 140 |
+
self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
|
| 141 |
+
colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
|
| 142 |
+
colors = colors[:, None, None] # shape(n,1,1,3)
|
| 143 |
+
masks = masks.unsqueeze(3) # shape(n,h,w,1)
|
| 144 |
+
masks_color = masks * (colors * alpha) # shape(n,h,w,3)
|
| 145 |
+
|
| 146 |
+
inv_alph_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
|
| 147 |
+
mcs = (masks_color * inv_alph_masks).sum(0) * 2 # mask color summand shape(n,h,w,3)
|
| 148 |
+
|
| 149 |
+
im_gpu = im_gpu.flip(dims=[0]) # flip channel
|
| 150 |
+
im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
|
| 151 |
+
im_gpu = im_gpu * inv_alph_masks[-1] + mcs
|
| 152 |
+
im_mask = (im_gpu * 255).byte().cpu().numpy()
|
| 153 |
+
self.im[:] = scale_image(im_gpu.shape, im_mask, self.im.shape)
|
| 154 |
+
if self.pil:
|
| 155 |
+
# convert im back to PIL and update draw
|
| 156 |
+
self.fromarray(self.im)
|
| 157 |
+
|
| 158 |
+
def rectangle(self, xy, fill=None, outline=None, width=1):
|
| 159 |
+
# Add rectangle to image (PIL-only)
|
| 160 |
+
self.draw.rectangle(xy, fill, outline, width)
|
| 161 |
+
|
| 162 |
+
def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
|
| 163 |
+
# Add text to image (PIL-only)
|
| 164 |
+
if anchor == 'bottom': # start y from font bottom
|
| 165 |
+
w, h = self.font.getsize(text) # text width, height
|
| 166 |
+
xy[1] += 1 - h
|
| 167 |
+
self.draw.text(xy, text, fill=txt_color, font=self.font)
|
| 168 |
+
|
| 169 |
+
def fromarray(self, im):
|
| 170 |
+
# Update self.im from a numpy array
|
| 171 |
+
self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
|
| 172 |
+
self.draw = ImageDraw.Draw(self.im)
|
| 173 |
+
|
| 174 |
+
def result(self):
|
| 175 |
+
# Return annotated image as array
|
| 176 |
+
return np.asarray(self.im)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
|
| 180 |
+
"""
|
| 181 |
+
x: Features to be visualized
|
| 182 |
+
module_type: Module type
|
| 183 |
+
stage: Module stage within model
|
| 184 |
+
n: Maximum number of feature maps to plot
|
| 185 |
+
save_dir: Directory to save results
|
| 186 |
+
"""
|
| 187 |
+
if 'Detect' not in module_type:
|
| 188 |
+
batch, channels, height, width = x.shape # batch, channels, height, width
|
| 189 |
+
if height > 1 and width > 1:
|
| 190 |
+
f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename
|
| 191 |
+
|
| 192 |
+
blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
|
| 193 |
+
n = min(n, channels) # number of plots
|
| 194 |
+
fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols
|
| 195 |
+
ax = ax.ravel()
|
| 196 |
+
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
| 197 |
+
for i in range(n):
|
| 198 |
+
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
|
| 199 |
+
ax[i].axis('off')
|
| 200 |
+
|
| 201 |
+
LOGGER.info(f'Saving {f}... ({n}/{channels})')
|
| 202 |
+
plt.savefig(f, dpi=300, bbox_inches='tight')
|
| 203 |
+
plt.close()
|
| 204 |
+
np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy()) # npy save
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def hist2d(x, y, n=100):
|
| 208 |
+
# 2d histogram used in labels.png and evolve.png
|
| 209 |
+
xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
|
| 210 |
+
hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
|
| 211 |
+
xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
|
| 212 |
+
yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
|
| 213 |
+
return np.log(hist[xidx, yidx])
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
|
| 217 |
+
from scipy.signal import butter, filtfilt
|
| 218 |
+
|
| 219 |
+
# https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
|
| 220 |
+
def butter_lowpass(cutoff, fs, order):
|
| 221 |
+
nyq = 0.5 * fs
|
| 222 |
+
normal_cutoff = cutoff / nyq
|
| 223 |
+
return butter(order, normal_cutoff, btype='low', analog=False)
|
| 224 |
+
|
| 225 |
+
b, a = butter_lowpass(cutoff, fs, order=order)
|
| 226 |
+
return filtfilt(b, a, data) # forward-backward filter
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def output_to_target(output, max_det=300):
|
| 230 |
+
# Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting
|
| 231 |
+
targets = []
|
| 232 |
+
for i, o in enumerate(output):
|
| 233 |
+
box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
|
| 234 |
+
j = torch.full((conf.shape[0], 1), i)
|
| 235 |
+
targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
|
| 236 |
+
return torch.cat(targets, 0).numpy()
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@threaded
|
| 240 |
+
def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
|
| 241 |
+
# Plot image grid with labels
|
| 242 |
+
if isinstance(images, torch.Tensor):
|
| 243 |
+
images = images.cpu().float().numpy()
|
| 244 |
+
if isinstance(targets, torch.Tensor):
|
| 245 |
+
targets = targets.cpu().numpy()
|
| 246 |
+
|
| 247 |
+
max_size = 1920 # max image size
|
| 248 |
+
max_subplots = 16 # max image subplots, i.e. 4x4
|
| 249 |
+
bs, _, h, w = images.shape # batch size, _, height, width
|
| 250 |
+
bs = min(bs, max_subplots) # limit plot images
|
| 251 |
+
ns = np.ceil(bs ** 0.5) # number of subplots (square)
|
| 252 |
+
if np.max(images[0]) <= 1:
|
| 253 |
+
images *= 255 # de-normalise (optional)
|
| 254 |
+
|
| 255 |
+
# Build Image
|
| 256 |
+
mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
|
| 257 |
+
for i, im in enumerate(images):
|
| 258 |
+
if i == max_subplots: # if last batch has fewer images than we expect
|
| 259 |
+
break
|
| 260 |
+
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
| 261 |
+
im = im.transpose(1, 2, 0)
|
| 262 |
+
mosaic[y:y + h, x:x + w, :] = im
|
| 263 |
+
|
| 264 |
+
# Resize (optional)
|
| 265 |
+
scale = max_size / ns / max(h, w)
|
| 266 |
+
if scale < 1:
|
| 267 |
+
h = math.ceil(scale * h)
|
| 268 |
+
w = math.ceil(scale * w)
|
| 269 |
+
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
|
| 270 |
+
|
| 271 |
+
# Annotate
|
| 272 |
+
fs = int((h + w) * ns * 0.01) # font size
|
| 273 |
+
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
|
| 274 |
+
for i in range(i + 1):
|
| 275 |
+
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
|
| 276 |
+
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
|
| 277 |
+
if paths:
|
| 278 |
+
annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
|
| 279 |
+
if len(targets) > 0:
|
| 280 |
+
ti = targets[targets[:, 0] == i] # image targets
|
| 281 |
+
boxes = xywh2xyxy(ti[:, 2:6]).T
|
| 282 |
+
classes = ti[:, 1].astype('int')
|
| 283 |
+
labels = ti.shape[1] == 6 # labels if no conf column
|
| 284 |
+
conf = None if labels else ti[:, 6] # check for confidence presence (label vs pred)
|
| 285 |
+
|
| 286 |
+
if boxes.shape[1]:
|
| 287 |
+
if boxes.max() <= 1.01: # if normalized with tolerance 0.01
|
| 288 |
+
boxes[[0, 2]] *= w # scale to pixels
|
| 289 |
+
boxes[[1, 3]] *= h
|
| 290 |
+
elif scale < 1: # absolute coords need scale if image scales
|
| 291 |
+
boxes *= scale
|
| 292 |
+
boxes[[0, 2]] += x
|
| 293 |
+
boxes[[1, 3]] += y
|
| 294 |
+
for j, box in enumerate(boxes.T.tolist()):
|
| 295 |
+
cls = classes[j]
|
| 296 |
+
color = colors(cls)
|
| 297 |
+
cls = names[cls] if names else cls
|
| 298 |
+
if labels or conf[j] > 0.25: # 0.25 conf thresh
|
| 299 |
+
label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
|
| 300 |
+
annotator.box_label(box, label, color=color)
|
| 301 |
+
annotator.im.save(fname) # save
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
|
| 305 |
+
# Plot LR simulating training for full epochs
|
| 306 |
+
optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
|
| 307 |
+
y = []
|
| 308 |
+
for _ in range(epochs):
|
| 309 |
+
scheduler.step()
|
| 310 |
+
y.append(optimizer.param_groups[0]['lr'])
|
| 311 |
+
plt.plot(y, '.-', label='LR')
|
| 312 |
+
plt.xlabel('epoch')
|
| 313 |
+
plt.ylabel('LR')
|
| 314 |
+
plt.grid()
|
| 315 |
+
plt.xlim(0, epochs)
|
| 316 |
+
plt.ylim(0)
|
| 317 |
+
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
|
| 318 |
+
plt.close()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def plot_val_txt(): # from utils.plots import *; plot_val()
|
| 322 |
+
# Plot val.txt histograms
|
| 323 |
+
x = np.loadtxt('val.txt', dtype=np.float32)
|
| 324 |
+
box = xyxy2xywh(x[:, :4])
|
| 325 |
+
cx, cy = box[:, 0], box[:, 1]
|
| 326 |
+
|
| 327 |
+
fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
|
| 328 |
+
ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
|
| 329 |
+
ax.set_aspect('equal')
|
| 330 |
+
plt.savefig('hist2d.png', dpi=300)
|
| 331 |
+
|
| 332 |
+
fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
|
| 333 |
+
ax[0].hist(cx, bins=600)
|
| 334 |
+
ax[1].hist(cy, bins=600)
|
| 335 |
+
plt.savefig('hist1d.png', dpi=200)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()
|
| 339 |
+
# Plot targets.txt histograms
|
| 340 |
+
x = np.loadtxt('targets.txt', dtype=np.float32).T
|
| 341 |
+
s = ['x targets', 'y targets', 'width targets', 'height targets']
|
| 342 |
+
fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
|
| 343 |
+
ax = ax.ravel()
|
| 344 |
+
for i in range(4):
|
| 345 |
+
ax[i].hist(x[i], bins=100, label=f'{x[i].mean():.3g} +/- {x[i].std():.3g}')
|
| 346 |
+
ax[i].legend()
|
| 347 |
+
ax[i].set_title(s[i])
|
| 348 |
+
plt.savefig('targets.jpg', dpi=200)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_val_study()
|
| 352 |
+
# Plot file=study.txt generated by val.py (or plot all study*.txt in dir)
|
| 353 |
+
save_dir = Path(file).parent if file else Path(dir)
|
| 354 |
+
plot2 = False # plot additional results
|
| 355 |
+
if plot2:
|
| 356 |
+
ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()
|
| 357 |
+
|
| 358 |
+
fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
|
| 359 |
+
# for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
|
| 360 |
+
for f in sorted(save_dir.glob('study*.txt')):
|
| 361 |
+
y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
|
| 362 |
+
x = np.arange(y.shape[1]) if x is None else np.array(x)
|
| 363 |
+
if plot2:
|
| 364 |
+
s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
|
| 365 |
+
for i in range(7):
|
| 366 |
+
ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
|
| 367 |
+
ax[i].set_title(s[i])
|
| 368 |
+
|
| 369 |
+
j = y[3].argmax() + 1
|
| 370 |
+
ax2.plot(y[5, 1:j],
|
| 371 |
+
y[3, 1:j] * 1E2,
|
| 372 |
+
'.-',
|
| 373 |
+
linewidth=2,
|
| 374 |
+
markersize=8,
|
| 375 |
+
label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))
|
| 376 |
+
|
| 377 |
+
ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
|
| 378 |
+
'k.-',
|
| 379 |
+
linewidth=2,
|
| 380 |
+
markersize=8,
|
| 381 |
+
alpha=.25,
|
| 382 |
+
label='EfficientDet')
|
| 383 |
+
|
| 384 |
+
ax2.grid(alpha=0.2)
|
| 385 |
+
ax2.set_yticks(np.arange(20, 60, 5))
|
| 386 |
+
ax2.set_xlim(0, 57)
|
| 387 |
+
ax2.set_ylim(25, 55)
|
| 388 |
+
ax2.set_xlabel('GPU Speed (ms/img)')
|
| 389 |
+
ax2.set_ylabel('COCO AP val')
|
| 390 |
+
ax2.legend(loc='lower right')
|
| 391 |
+
f = save_dir / 'study.png'
|
| 392 |
+
print(f'Saving {f}...')
|
| 393 |
+
plt.savefig(f, dpi=300)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
|
| 397 |
+
def plot_labels(labels, names=(), save_dir=Path('')):
|
| 398 |
+
# plot dataset labels
|
| 399 |
+
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
|
| 400 |
+
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
|
| 401 |
+
nc = int(c.max() + 1) # number of classes
|
| 402 |
+
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
|
| 403 |
+
|
| 404 |
+
# seaborn correlogram
|
| 405 |
+
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
|
| 406 |
+
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
|
| 407 |
+
plt.close()
|
| 408 |
+
|
| 409 |
+
# matplotlib labels
|
| 410 |
+
matplotlib.use('svg') # faster
|
| 411 |
+
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
|
| 412 |
+
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
|
| 413 |
+
with contextlib.suppress(Exception): # color histogram bars by class
|
| 414 |
+
[y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
|
| 415 |
+
ax[0].set_ylabel('instances')
|
| 416 |
+
if 0 < len(names) < 30:
|
| 417 |
+
ax[0].set_xticks(range(len(names)))
|
| 418 |
+
ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
|
| 419 |
+
else:
|
| 420 |
+
ax[0].set_xlabel('classes')
|
| 421 |
+
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
|
| 422 |
+
sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
|
| 423 |
+
|
| 424 |
+
# rectangles
|
| 425 |
+
labels[:, 1:3] = 0.5 # center
|
| 426 |
+
labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
|
| 427 |
+
img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
|
| 428 |
+
for cls, *box in labels[:1000]:
|
| 429 |
+
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
|
| 430 |
+
ax[1].imshow(img)
|
| 431 |
+
ax[1].axis('off')
|
| 432 |
+
|
| 433 |
+
for a in [0, 1, 2, 3]:
|
| 434 |
+
for s in ['top', 'right', 'left', 'bottom']:
|
| 435 |
+
ax[a].spines[s].set_visible(False)
|
| 436 |
+
|
| 437 |
+
plt.savefig(save_dir / 'labels.jpg', dpi=200)
|
| 438 |
+
matplotlib.use('Agg')
|
| 439 |
+
plt.close()
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
|
| 443 |
+
# Show classification image grid with labels (optional) and predictions (optional)
|
| 444 |
+
from utils.augmentations import denormalize
|
| 445 |
+
|
| 446 |
+
names = names or [f'class{i}' for i in range(1000)]
|
| 447 |
+
blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
|
| 448 |
+
dim=0) # select batch index 0, block by channels
|
| 449 |
+
n = min(len(blocks), nmax) # number of plots
|
| 450 |
+
m = min(8, round(n ** 0.5)) # 8 x 8 default
|
| 451 |
+
fig, ax = plt.subplots(math.ceil(n / m), m) # 8 rows x n/8 cols
|
| 452 |
+
ax = ax.ravel() if m > 1 else [ax]
|
| 453 |
+
# plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
| 454 |
+
for i in range(n):
|
| 455 |
+
ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
|
| 456 |
+
ax[i].axis('off')
|
| 457 |
+
if labels is not None:
|
| 458 |
+
s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
|
| 459 |
+
ax[i].set_title(s, fontsize=8, verticalalignment='top')
|
| 460 |
+
plt.savefig(f, dpi=300, bbox_inches='tight')
|
| 461 |
+
plt.close()
|
| 462 |
+
if verbose:
|
| 463 |
+
LOGGER.info(f"Saving {f}")
|
| 464 |
+
if labels is not None:
|
| 465 |
+
LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
|
| 466 |
+
if pred is not None:
|
| 467 |
+
LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
|
| 468 |
+
return f
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
|
| 472 |
+
# Plot evolve.csv hyp evolution results
|
| 473 |
+
evolve_csv = Path(evolve_csv)
|
| 474 |
+
data = pd.read_csv(evolve_csv)
|
| 475 |
+
keys = [x.strip() for x in data.columns]
|
| 476 |
+
x = data.values
|
| 477 |
+
f = fitness(x)
|
| 478 |
+
j = np.argmax(f) # max fitness index
|
| 479 |
+
plt.figure(figsize=(10, 12), tight_layout=True)
|
| 480 |
+
matplotlib.rc('font', **{'size': 8})
|
| 481 |
+
print(f'Best results from row {j} of {evolve_csv}:')
|
| 482 |
+
for i, k in enumerate(keys[7:]):
|
| 483 |
+
v = x[:, 7 + i]
|
| 484 |
+
mu = v[j] # best single result
|
| 485 |
+
plt.subplot(6, 5, i + 1)
|
| 486 |
+
plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
|
| 487 |
+
plt.plot(mu, f.max(), 'k+', markersize=15)
|
| 488 |
+
plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9}) # limit to 40 characters
|
| 489 |
+
if i % 5 != 0:
|
| 490 |
+
plt.yticks([])
|
| 491 |
+
print(f'{k:>15}: {mu:.3g}')
|
| 492 |
+
f = evolve_csv.with_suffix('.png') # filename
|
| 493 |
+
plt.savefig(f, dpi=200)
|
| 494 |
+
plt.close()
|
| 495 |
+
print(f'Saved {f}')
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
def plot_results(file='path/to/results.csv', dir=''):
|
| 499 |
+
# Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
|
| 500 |
+
save_dir = Path(file).parent if file else Path(dir)
|
| 501 |
+
fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
|
| 502 |
+
ax = ax.ravel()
|
| 503 |
+
files = list(save_dir.glob('results*.csv'))
|
| 504 |
+
assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
|
| 505 |
+
for f in files:
|
| 506 |
+
try:
|
| 507 |
+
data = pd.read_csv(f)
|
| 508 |
+
s = [x.strip() for x in data.columns]
|
| 509 |
+
x = data.values[:, 0]
|
| 510 |
+
for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):
|
| 511 |
+
y = data.values[:, j].astype('float')
|
| 512 |
+
# y[y == 0] = np.nan # don't show zero values
|
| 513 |
+
ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
|
| 514 |
+
ax[i].set_title(s[j], fontsize=12)
|
| 515 |
+
# if j in [8, 9, 10]: # share train and val loss y axes
|
| 516 |
+
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
| 517 |
+
except Exception as e:
|
| 518 |
+
LOGGER.info(f'Warning: Plotting error for {f}: {e}')
|
| 519 |
+
ax[1].legend()
|
| 520 |
+
fig.savefig(save_dir / 'results.png', dpi=200)
|
| 521 |
+
plt.close()
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
|
| 525 |
+
# Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
|
| 526 |
+
ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
|
| 527 |
+
s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
|
| 528 |
+
files = list(Path(save_dir).glob('frames*.txt'))
|
| 529 |
+
for fi, f in enumerate(files):
|
| 530 |
+
try:
|
| 531 |
+
results = np.loadtxt(f, ndmin=2).T[:, 90:-30] # clip first and last rows
|
| 532 |
+
n = results.shape[1] # number of rows
|
| 533 |
+
x = np.arange(start, min(stop, n) if stop else n)
|
| 534 |
+
results = results[:, x]
|
| 535 |
+
t = (results[0] - results[0].min()) # set t0=0s
|
| 536 |
+
results[0] = x
|
| 537 |
+
for i, a in enumerate(ax):
|
| 538 |
+
if i < len(results):
|
| 539 |
+
label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
|
| 540 |
+
a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
|
| 541 |
+
a.set_title(s[i])
|
| 542 |
+
a.set_xlabel('time (s)')
|
| 543 |
+
# if fi == len(files) - 1:
|
| 544 |
+
# a.set_ylim(bottom=0)
|
| 545 |
+
for side in ['top', 'right']:
|
| 546 |
+
a.spines[side].set_visible(False)
|
| 547 |
+
else:
|
| 548 |
+
a.remove()
|
| 549 |
+
except Exception as e:
|
| 550 |
+
print(f'Warning: Plotting error for {f}; {e}')
|
| 551 |
+
ax[1].legend()
|
| 552 |
+
plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
|
| 556 |
+
# Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop
|
| 557 |
+
xyxy = torch.tensor(xyxy).view(-1, 4)
|
| 558 |
+
b = xyxy2xywh(xyxy) # boxes
|
| 559 |
+
if square:
|
| 560 |
+
b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
|
| 561 |
+
b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
|
| 562 |
+
xyxy = xywh2xyxy(b).long()
|
| 563 |
+
clip_boxes(xyxy, im.shape)
|
| 564 |
+
crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
|
| 565 |
+
if save:
|
| 566 |
+
file.parent.mkdir(parents=True, exist_ok=True) # make directory
|
| 567 |
+
f = str(increment_path(file).with_suffix('.jpg'))
|
| 568 |
+
# cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
|
| 569 |
+
Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0) # save RGB
|
| 570 |
+
return crop
|
torch_utils.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import platform
|
| 4 |
+
import subprocess
|
| 5 |
+
import time
|
| 6 |
+
import warnings
|
| 7 |
+
from contextlib import contextmanager
|
| 8 |
+
from copy import deepcopy
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.distributed as dist
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 16 |
+
|
| 17 |
+
from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
|
| 18 |
+
from utils.lion import Lion
|
| 19 |
+
|
| 20 |
+
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
|
| 21 |
+
RANK = int(os.getenv('RANK', -1))
|
| 22 |
+
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
import thop # for FLOPs computation
|
| 26 |
+
except ImportError:
|
| 27 |
+
thop = None
|
| 28 |
+
|
| 29 |
+
# Suppress PyTorch warnings
|
| 30 |
+
warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
|
| 31 |
+
warnings.filterwarnings('ignore', category=UserWarning)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
|
| 35 |
+
# Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
|
| 36 |
+
def decorate(fn):
|
| 37 |
+
return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
|
| 38 |
+
|
| 39 |
+
return decorate
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def smartCrossEntropyLoss(label_smoothing=0.0):
|
| 43 |
+
# Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
|
| 44 |
+
if check_version(torch.__version__, '1.10.0'):
|
| 45 |
+
return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
| 46 |
+
if label_smoothing > 0:
|
| 47 |
+
LOGGER.warning(f'WARNING ⚠️ label smoothing {label_smoothing} requires torch>=1.10.0')
|
| 48 |
+
return nn.CrossEntropyLoss()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def smart_DDP(model):
|
| 52 |
+
# Model DDP creation with checks
|
| 53 |
+
assert not check_version(torch.__version__, '1.12.0', pinned=True), \
|
| 54 |
+
'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
|
| 55 |
+
'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
|
| 56 |
+
if check_version(torch.__version__, '1.11.0'):
|
| 57 |
+
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
|
| 58 |
+
else:
|
| 59 |
+
return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def reshape_classifier_output(model, n=1000):
|
| 63 |
+
# Update a TorchVision classification model to class count 'n' if required
|
| 64 |
+
from models.common import Classify
|
| 65 |
+
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
| 66 |
+
if isinstance(m, Classify): # YOLOv5 Classify() head
|
| 67 |
+
if m.linear.out_features != n:
|
| 68 |
+
m.linear = nn.Linear(m.linear.in_features, n)
|
| 69 |
+
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
| 70 |
+
if m.out_features != n:
|
| 71 |
+
setattr(model, name, nn.Linear(m.in_features, n))
|
| 72 |
+
elif isinstance(m, nn.Sequential):
|
| 73 |
+
types = [type(x) for x in m]
|
| 74 |
+
if nn.Linear in types:
|
| 75 |
+
i = types.index(nn.Linear) # nn.Linear index
|
| 76 |
+
if m[i].out_features != n:
|
| 77 |
+
m[i] = nn.Linear(m[i].in_features, n)
|
| 78 |
+
elif nn.Conv2d in types:
|
| 79 |
+
i = types.index(nn.Conv2d) # nn.Conv2d index
|
| 80 |
+
if m[i].out_channels != n:
|
| 81 |
+
m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@contextmanager
|
| 85 |
+
def torch_distributed_zero_first(local_rank: int):
|
| 86 |
+
# Decorator to make all processes in distributed training wait for each local_master to do something
|
| 87 |
+
if local_rank not in [-1, 0]:
|
| 88 |
+
dist.barrier(device_ids=[local_rank])
|
| 89 |
+
yield
|
| 90 |
+
if local_rank == 0:
|
| 91 |
+
dist.barrier(device_ids=[0])
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def device_count():
|
| 95 |
+
# Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows
|
| 96 |
+
assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
|
| 97 |
+
try:
|
| 98 |
+
cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows
|
| 99 |
+
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
|
| 100 |
+
except Exception:
|
| 101 |
+
return 0
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def select_device(device='', batch_size=0, newline=True):
|
| 105 |
+
# device = None or 'cpu' or 0 or '0' or '0,1,2,3'
|
| 106 |
+
s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
|
| 107 |
+
device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
|
| 108 |
+
cpu = device == 'cpu'
|
| 109 |
+
mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
|
| 110 |
+
if cpu or mps:
|
| 111 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
| 112 |
+
elif device: # non-cpu device requested
|
| 113 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
|
| 114 |
+
assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
|
| 115 |
+
f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
|
| 116 |
+
|
| 117 |
+
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
| 118 |
+
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
| 119 |
+
n = len(devices) # device count
|
| 120 |
+
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
| 121 |
+
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
| 122 |
+
space = ' ' * (len(s) + 1)
|
| 123 |
+
for i, d in enumerate(devices):
|
| 124 |
+
p = torch.cuda.get_device_properties(i)
|
| 125 |
+
s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
|
| 126 |
+
arg = 'cuda:0'
|
| 127 |
+
elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available(): # prefer MPS if available
|
| 128 |
+
s += 'MPS\n'
|
| 129 |
+
arg = 'mps'
|
| 130 |
+
else: # revert to CPU
|
| 131 |
+
s += 'CPU\n'
|
| 132 |
+
arg = 'cpu'
|
| 133 |
+
|
| 134 |
+
if not newline:
|
| 135 |
+
s = s.rstrip()
|
| 136 |
+
LOGGER.info(s)
|
| 137 |
+
return torch.device(arg)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def time_sync():
|
| 141 |
+
# PyTorch-accurate time
|
| 142 |
+
if torch.cuda.is_available():
|
| 143 |
+
torch.cuda.synchronize()
|
| 144 |
+
return time.time()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def profile(input, ops, n=10, device=None):
|
| 148 |
+
""" YOLOv5 speed/memory/FLOPs profiler
|
| 149 |
+
Usage:
|
| 150 |
+
input = torch.randn(16, 3, 640, 640)
|
| 151 |
+
m1 = lambda x: x * torch.sigmoid(x)
|
| 152 |
+
m2 = nn.SiLU()
|
| 153 |
+
profile(input, [m1, m2], n=100) # profile over 100 iterations
|
| 154 |
+
"""
|
| 155 |
+
results = []
|
| 156 |
+
if not isinstance(device, torch.device):
|
| 157 |
+
device = select_device(device)
|
| 158 |
+
print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
| 159 |
+
f"{'input':>24s}{'output':>24s}")
|
| 160 |
+
|
| 161 |
+
for x in input if isinstance(input, list) else [input]:
|
| 162 |
+
x = x.to(device)
|
| 163 |
+
x.requires_grad = True
|
| 164 |
+
for m in ops if isinstance(ops, list) else [ops]:
|
| 165 |
+
m = m.to(device) if hasattr(m, 'to') else m # device
|
| 166 |
+
m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
|
| 167 |
+
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
|
| 168 |
+
try:
|
| 169 |
+
flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 # GFLOPs
|
| 170 |
+
except Exception:
|
| 171 |
+
flops = 0
|
| 172 |
+
|
| 173 |
+
try:
|
| 174 |
+
for _ in range(n):
|
| 175 |
+
t[0] = time_sync()
|
| 176 |
+
y = m(x)
|
| 177 |
+
t[1] = time_sync()
|
| 178 |
+
try:
|
| 179 |
+
_ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
|
| 180 |
+
t[2] = time_sync()
|
| 181 |
+
except Exception: # no backward method
|
| 182 |
+
# print(e) # for debug
|
| 183 |
+
t[2] = float('nan')
|
| 184 |
+
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
| 185 |
+
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
| 186 |
+
mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
|
| 187 |
+
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
|
| 188 |
+
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
| 189 |
+
print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
|
| 190 |
+
results.append([p, flops, mem, tf, tb, s_in, s_out])
|
| 191 |
+
except Exception as e:
|
| 192 |
+
print(e)
|
| 193 |
+
results.append(None)
|
| 194 |
+
torch.cuda.empty_cache()
|
| 195 |
+
return results
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def is_parallel(model):
|
| 199 |
+
# Returns True if model is of type DP or DDP
|
| 200 |
+
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def de_parallel(model):
|
| 204 |
+
# De-parallelize a model: returns single-GPU model if model is of type DP or DDP
|
| 205 |
+
return model.module if is_parallel(model) else model
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def initialize_weights(model):
|
| 209 |
+
for m in model.modules():
|
| 210 |
+
t = type(m)
|
| 211 |
+
if t is nn.Conv2d:
|
| 212 |
+
pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 213 |
+
elif t is nn.BatchNorm2d:
|
| 214 |
+
m.eps = 1e-3
|
| 215 |
+
m.momentum = 0.03
|
| 216 |
+
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
| 217 |
+
m.inplace = True
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def find_modules(model, mclass=nn.Conv2d):
|
| 221 |
+
# Finds layer indices matching module class 'mclass'
|
| 222 |
+
return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def sparsity(model):
|
| 226 |
+
# Return global model sparsity
|
| 227 |
+
a, b = 0, 0
|
| 228 |
+
for p in model.parameters():
|
| 229 |
+
a += p.numel()
|
| 230 |
+
b += (p == 0).sum()
|
| 231 |
+
return b / a
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def prune(model, amount=0.3):
|
| 235 |
+
# Prune model to requested global sparsity
|
| 236 |
+
import torch.nn.utils.prune as prune
|
| 237 |
+
for name, m in model.named_modules():
|
| 238 |
+
if isinstance(m, nn.Conv2d):
|
| 239 |
+
prune.l1_unstructured(m, name='weight', amount=amount) # prune
|
| 240 |
+
prune.remove(m, 'weight') # make permanent
|
| 241 |
+
LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def fuse_conv_and_bn(conv, bn):
|
| 245 |
+
# Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
|
| 246 |
+
fusedconv = nn.Conv2d(conv.in_channels,
|
| 247 |
+
conv.out_channels,
|
| 248 |
+
kernel_size=conv.kernel_size,
|
| 249 |
+
stride=conv.stride,
|
| 250 |
+
padding=conv.padding,
|
| 251 |
+
dilation=conv.dilation,
|
| 252 |
+
groups=conv.groups,
|
| 253 |
+
bias=True).requires_grad_(False).to(conv.weight.device)
|
| 254 |
+
|
| 255 |
+
# Prepare filters
|
| 256 |
+
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
| 257 |
+
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
| 258 |
+
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
|
| 259 |
+
|
| 260 |
+
# Prepare spatial bias
|
| 261 |
+
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
| 262 |
+
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
| 263 |
+
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
|
| 264 |
+
|
| 265 |
+
return fusedconv
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def model_info(model, verbose=False, imgsz=640):
|
| 269 |
+
# Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
|
| 270 |
+
n_p = sum(x.numel() for x in model.parameters()) # number parameters
|
| 271 |
+
n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
|
| 272 |
+
if verbose:
|
| 273 |
+
print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
|
| 274 |
+
for i, (name, p) in enumerate(model.named_parameters()):
|
| 275 |
+
name = name.replace('module_list.', '')
|
| 276 |
+
print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
|
| 277 |
+
(i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
|
| 278 |
+
|
| 279 |
+
try: # FLOPs
|
| 280 |
+
p = next(model.parameters())
|
| 281 |
+
stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
|
| 282 |
+
im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
|
| 283 |
+
flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
|
| 284 |
+
imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
|
| 285 |
+
fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
|
| 286 |
+
except Exception:
|
| 287 |
+
fs = ''
|
| 288 |
+
|
| 289 |
+
name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
|
| 290 |
+
LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
|
| 294 |
+
# Scales img(bs,3,y,x) by ratio constrained to gs-multiple
|
| 295 |
+
if ratio == 1.0:
|
| 296 |
+
return img
|
| 297 |
+
h, w = img.shape[2:]
|
| 298 |
+
s = (int(h * ratio), int(w * ratio)) # new size
|
| 299 |
+
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
|
| 300 |
+
if not same_shape: # pad/crop img
|
| 301 |
+
h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
|
| 302 |
+
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def copy_attr(a, b, include=(), exclude=()):
|
| 306 |
+
# Copy attributes from b to a, options to only include [...] and to exclude [...]
|
| 307 |
+
for k, v in b.__dict__.items():
|
| 308 |
+
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
| 309 |
+
continue
|
| 310 |
+
else:
|
| 311 |
+
setattr(a, k, v)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
| 315 |
+
# YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
|
| 316 |
+
g = [], [], [] # optimizer parameter groups
|
| 317 |
+
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
| 318 |
+
#for v in model.modules():
|
| 319 |
+
# for p_name, p in v.named_parameters(recurse=0):
|
| 320 |
+
# if p_name == 'bias': # bias (no decay)
|
| 321 |
+
# g[2].append(p)
|
| 322 |
+
# elif p_name == 'weight' and isinstance(v, bn): # weight (no decay)
|
| 323 |
+
# g[1].append(p)
|
| 324 |
+
# else:
|
| 325 |
+
# g[0].append(p) # weight (with decay)
|
| 326 |
+
|
| 327 |
+
for v in model.modules():
|
| 328 |
+
if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias (no decay)
|
| 329 |
+
g[2].append(v.bias)
|
| 330 |
+
if isinstance(v, bn): # weight (no decay)
|
| 331 |
+
g[1].append(v.weight)
|
| 332 |
+
elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
|
| 333 |
+
g[0].append(v.weight)
|
| 334 |
+
|
| 335 |
+
if hasattr(v, 'im'):
|
| 336 |
+
if hasattr(v.im, 'implicit'):
|
| 337 |
+
g[1].append(v.im.implicit)
|
| 338 |
+
else:
|
| 339 |
+
for iv in v.im:
|
| 340 |
+
g[1].append(iv.implicit)
|
| 341 |
+
if hasattr(v, 'ia'):
|
| 342 |
+
if hasattr(v.ia, 'implicit'):
|
| 343 |
+
g[1].append(v.ia.implicit)
|
| 344 |
+
else:
|
| 345 |
+
for iv in v.ia:
|
| 346 |
+
g[1].append(iv.implicit)
|
| 347 |
+
|
| 348 |
+
if hasattr(v, 'im2'):
|
| 349 |
+
if hasattr(v.im2, 'implicit'):
|
| 350 |
+
g[1].append(v.im2.implicit)
|
| 351 |
+
else:
|
| 352 |
+
for iv in v.im2:
|
| 353 |
+
g[1].append(iv.implicit)
|
| 354 |
+
if hasattr(v, 'ia2'):
|
| 355 |
+
if hasattr(v.ia2, 'implicit'):
|
| 356 |
+
g[1].append(v.ia2.implicit)
|
| 357 |
+
else:
|
| 358 |
+
for iv in v.ia2:
|
| 359 |
+
g[1].append(iv.implicit)
|
| 360 |
+
|
| 361 |
+
if hasattr(v, 'im3'):
|
| 362 |
+
if hasattr(v.im3, 'implicit'):
|
| 363 |
+
g[1].append(v.im3.implicit)
|
| 364 |
+
else:
|
| 365 |
+
for iv in v.im3:
|
| 366 |
+
g[1].append(iv.implicit)
|
| 367 |
+
if hasattr(v, 'ia3'):
|
| 368 |
+
if hasattr(v.ia3, 'implicit'):
|
| 369 |
+
g[1].append(v.ia3.implicit)
|
| 370 |
+
else:
|
| 371 |
+
for iv in v.ia3:
|
| 372 |
+
g[1].append(iv.implicit)
|
| 373 |
+
|
| 374 |
+
if hasattr(v, 'im4'):
|
| 375 |
+
if hasattr(v.im4, 'implicit'):
|
| 376 |
+
g[1].append(v.im4.implicit)
|
| 377 |
+
else:
|
| 378 |
+
for iv in v.im4:
|
| 379 |
+
g[1].append(iv.implicit)
|
| 380 |
+
if hasattr(v, 'ia4'):
|
| 381 |
+
if hasattr(v.ia4, 'implicit'):
|
| 382 |
+
g[1].append(v.ia4.implicit)
|
| 383 |
+
else:
|
| 384 |
+
for iv in v.ia4:
|
| 385 |
+
g[1].append(iv.implicit)
|
| 386 |
+
|
| 387 |
+
if hasattr(v, 'im5'):
|
| 388 |
+
if hasattr(v.im5, 'implicit'):
|
| 389 |
+
g[1].append(v.im5.implicit)
|
| 390 |
+
else:
|
| 391 |
+
for iv in v.im5:
|
| 392 |
+
g[1].append(iv.implicit)
|
| 393 |
+
if hasattr(v, 'ia5'):
|
| 394 |
+
if hasattr(v.ia5, 'implicit'):
|
| 395 |
+
g[1].append(v.ia5.implicit)
|
| 396 |
+
else:
|
| 397 |
+
for iv in v.ia5:
|
| 398 |
+
g[1].append(iv.implicit)
|
| 399 |
+
|
| 400 |
+
if hasattr(v, 'im6'):
|
| 401 |
+
if hasattr(v.im6, 'implicit'):
|
| 402 |
+
g[1].append(v.im6.implicit)
|
| 403 |
+
else:
|
| 404 |
+
for iv in v.im6:
|
| 405 |
+
g[1].append(iv.implicit)
|
| 406 |
+
if hasattr(v, 'ia6'):
|
| 407 |
+
if hasattr(v.ia6, 'implicit'):
|
| 408 |
+
g[1].append(v.ia6.implicit)
|
| 409 |
+
else:
|
| 410 |
+
for iv in v.ia6:
|
| 411 |
+
g[1].append(iv.implicit)
|
| 412 |
+
|
| 413 |
+
if hasattr(v, 'im7'):
|
| 414 |
+
if hasattr(v.im7, 'implicit'):
|
| 415 |
+
g[1].append(v.im7.implicit)
|
| 416 |
+
else:
|
| 417 |
+
for iv in v.im7:
|
| 418 |
+
g[1].append(iv.implicit)
|
| 419 |
+
if hasattr(v, 'ia7'):
|
| 420 |
+
if hasattr(v.ia7, 'implicit'):
|
| 421 |
+
g[1].append(v.ia7.implicit)
|
| 422 |
+
else:
|
| 423 |
+
for iv in v.ia7:
|
| 424 |
+
g[1].append(iv.implicit)
|
| 425 |
+
|
| 426 |
+
if name == 'Adam':
|
| 427 |
+
optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999)) # adjust beta1 to momentum
|
| 428 |
+
elif name == 'AdamW':
|
| 429 |
+
optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0, amsgrad=True)
|
| 430 |
+
elif name == 'RMSProp':
|
| 431 |
+
optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
| 432 |
+
elif name == 'SGD':
|
| 433 |
+
optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
| 434 |
+
elif name == 'LION':
|
| 435 |
+
optimizer = Lion(g[2], lr=lr, betas=(momentum, 0.99), weight_decay=0.0)
|
| 436 |
+
else:
|
| 437 |
+
raise NotImplementedError(f'Optimizer {name} not implemented.')
|
| 438 |
+
|
| 439 |
+
optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
|
| 440 |
+
optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
|
| 441 |
+
LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
|
| 442 |
+
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
|
| 443 |
+
return optimizer
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
|
| 447 |
+
# YOLOv5 torch.hub.load() wrapper with smart error/issue handling
|
| 448 |
+
if check_version(torch.__version__, '1.9.1'):
|
| 449 |
+
kwargs['skip_validation'] = True # validation causes GitHub API rate limit errors
|
| 450 |
+
if check_version(torch.__version__, '1.12.0'):
|
| 451 |
+
kwargs['trust_repo'] = True # argument required starting in torch 0.12
|
| 452 |
+
try:
|
| 453 |
+
return torch.hub.load(repo, model, **kwargs)
|
| 454 |
+
except Exception:
|
| 455 |
+
return torch.hub.load(repo, model, force_reload=True, **kwargs)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
|
| 459 |
+
# Resume training from a partially trained checkpoint
|
| 460 |
+
best_fitness = 0.0
|
| 461 |
+
start_epoch = ckpt['epoch'] + 1
|
| 462 |
+
if ckpt['optimizer'] is not None:
|
| 463 |
+
optimizer.load_state_dict(ckpt['optimizer']) # optimizer
|
| 464 |
+
best_fitness = ckpt['best_fitness']
|
| 465 |
+
if ema and ckpt.get('ema'):
|
| 466 |
+
ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
|
| 467 |
+
ema.updates = ckpt['updates']
|
| 468 |
+
if resume:
|
| 469 |
+
assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
|
| 470 |
+
f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
|
| 471 |
+
LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
|
| 472 |
+
if epochs < start_epoch:
|
| 473 |
+
LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
|
| 474 |
+
epochs += ckpt['epoch'] # finetune additional epochs
|
| 475 |
+
return best_fitness, start_epoch, epochs
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
class EarlyStopping:
|
| 479 |
+
# YOLOv5 simple early stopper
|
| 480 |
+
def __init__(self, patience=30):
|
| 481 |
+
self.best_fitness = 0.0 # i.e. mAP
|
| 482 |
+
self.best_epoch = 0
|
| 483 |
+
self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
|
| 484 |
+
self.possible_stop = False # possible stop may occur next epoch
|
| 485 |
+
|
| 486 |
+
def __call__(self, epoch, fitness):
|
| 487 |
+
if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
|
| 488 |
+
self.best_epoch = epoch
|
| 489 |
+
self.best_fitness = fitness
|
| 490 |
+
delta = epoch - self.best_epoch # epochs without improvement
|
| 491 |
+
self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
|
| 492 |
+
stop = delta >= self.patience # stop training if patience exceeded
|
| 493 |
+
if stop:
|
| 494 |
+
LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
|
| 495 |
+
f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
|
| 496 |
+
f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
|
| 497 |
+
f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
|
| 498 |
+
return stop
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class ModelEMA:
|
| 502 |
+
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
| 503 |
+
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
| 504 |
+
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
| 505 |
+
"""
|
| 506 |
+
|
| 507 |
+
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
| 508 |
+
# Create EMA
|
| 509 |
+
self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
|
| 510 |
+
self.updates = updates # number of EMA updates
|
| 511 |
+
self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
|
| 512 |
+
for p in self.ema.parameters():
|
| 513 |
+
p.requires_grad_(False)
|
| 514 |
+
|
| 515 |
+
def update(self, model):
|
| 516 |
+
# Update EMA parameters
|
| 517 |
+
self.updates += 1
|
| 518 |
+
d = self.decay(self.updates)
|
| 519 |
+
|
| 520 |
+
msd = de_parallel(model).state_dict() # model state_dict
|
| 521 |
+
for k, v in self.ema.state_dict().items():
|
| 522 |
+
if v.dtype.is_floating_point: # true for FP16 and FP32
|
| 523 |
+
v *= d
|
| 524 |
+
v += (1 - d) * msd[k].detach()
|
| 525 |
+
# assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
|
| 526 |
+
|
| 527 |
+
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
| 528 |
+
# Update EMA attributes
|
| 529 |
+
copy_attr(self.ema, model, include, exclude)
|