Vvaann commited on
Commit
d3d6f15
·
verified ·
1 Parent(s): e94074c

Upload 3 files

Browse files
Files changed (3) hide show
  1. general.py +1135 -0
  2. plots.py +570 -0
  3. torch_utils.py +529 -0
general.py ADDED
@@ -0,0 +1,1135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import glob
3
+ import inspect
4
+ import logging
5
+ import logging.config
6
+ import math
7
+ import os
8
+ import platform
9
+ import random
10
+ import re
11
+ import signal
12
+ import sys
13
+ import time
14
+ import urllib
15
+ from copy import deepcopy
16
+ from datetime import datetime
17
+ from itertools import repeat
18
+ from multiprocessing.pool import ThreadPool
19
+ from pathlib import Path
20
+ from subprocess import check_output
21
+ from tarfile import is_tarfile
22
+ from typing import Optional
23
+ from zipfile import ZipFile, is_zipfile
24
+
25
+ import cv2
26
+ import IPython
27
+ import numpy as np
28
+ import pandas as pd
29
+ import pkg_resources as pkg
30
+ import torch
31
+ import torchvision
32
+ import yaml
33
+
34
+ from utils import TryExcept, emojis
35
+ from utils.downloads import gsutil_getsize
36
+ from utils.metrics import box_iou, fitness
37
+
38
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLO root directory
RANK = int(os.getenv('RANK', -1))  # distributed-training rank, -1 when not distributed

# Settings
# os.cpu_count() may return None on some platforms; guard before arithmetic
NUM_THREADS = min(8, max(1, (os.cpu_count() or 1) - 1))  # number of YOLOv5 multiprocessing threads
DATASETS_DIR = Path(os.getenv('YOLOv5_DATASETS_DIR', ROOT.parent / 'datasets'))  # global datasets directory
AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true'  # global auto-install mode
VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true'  # global verbose mode
TQDM_BAR_FORMAT = '{l_bar}{bar:10}| {n_fmt}/{total_fmt} {elapsed}'  # tqdm bar format
FONT = 'Arial.ttf'  # https://ultralytics.com/assets/Arial.ttf

torch.set_printoptions(linewidth=320, precision=5, profile='long')
np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
pd.options.display.max_columns = 10
cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS)  # NumExpr max threads
# platform.system() returns 'Darwin' on macOS (capitalized); comparing against
# lowercase 'darwin' never matched, so the macOS single-thread OpenMP fix was dead code
os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'Darwin' else str(NUM_THREADS)  # OpenMP (PyTorch and SciPy)
56
+
57
+
58
def is_ascii(s=''):
    """Return True if the string contains only ASCII characters (pre-3.7 safe, unlike str.isascii())."""
    text = str(s)  # coerce list, tuple, None, etc. to str
    stripped = text.encode().decode('ascii', 'ignore')  # drop any non-ASCII code points
    return len(stripped) == len(text)
62
+
63
+
64
def is_chinese(s='人工智能'):
    """Return True if the string contains any CJK unified ideograph."""
    return re.search('[\u4e00-\u9fff]', str(s)) is not None
67
+
68
+
69
def is_colab():
    """Return True when running inside a Google Colab instance (detected via loaded modules)."""
    loaded = sys.modules
    return 'google.colab' in loaded
72
+
73
+
74
def is_notebook():
    """Return True in Jupyter-like environments (verified on Colab, JupyterLab, Kaggle, Paperspace)."""
    shell = str(type(IPython.get_ipython()))  # e.g. "<class 'google.colab...'>" or zmqshell type
    return any(tag in shell for tag in ('colab', 'zmqshell'))
78
+
79
+
80
def is_kaggle():
    """Return True when running inside a Kaggle notebook (checks Kaggle-specific env vars)."""
    env = os.environ
    return env.get('PWD') == '/kaggle/working' and env.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
83
+
84
+
85
def is_docker() -> bool:
    """Return True if the current process appears to run inside a Docker container."""
    if Path("/.dockerenv").exists():  # marker file created by Docker
        return True
    try:
        # containerized processes list docker in their control groups
        with open("/proc/self/cgroup") as cgroup:
            return any("docker" in line for line in cgroup)
    except OSError:  # /proc missing (e.g. Windows, macOS)
        return False
94
+
95
+
96
def is_writeable(dir, test=False):
    """Return True if `dir` has write permission; with test=True verify by actually creating a file."""
    if not test:
        return os.access(dir, os.W_OK)  # may report incorrectly on Windows
    probe = Path(dir) / 'tmp.txt'
    try:
        open(probe, 'w').close()  # attempt a real write
        probe.unlink()  # clean up the probe file
        return True
    except OSError:
        return False
108
+
109
+
110
LOGGING_NAME = "yolov5"  # shared name for the logger, handler and formatter configured by set_logging()
111
+
112
+
113
def set_logging(name=LOGGING_NAME, verbose=True):
    """Configure a plain-message stream logger called `name`.

    Only rank -1/0 processes log at INFO; other distributed ranks are silenced to ERROR.
    """
    rank = int(os.getenv('RANK', -1))  # rank in world for Multi-GPU trainings
    level = logging.INFO if verbose and rank in {-1, 0} else logging.ERROR
    config = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {name: {"format": "%(message)s"}},
        "handlers": {name: {"class": "logging.StreamHandler", "formatter": name, "level": level}},
        "loggers": {name: {"level": level, "handlers": [name], "propagate": False}},
    }
    logging.config.dictConfig(config)
133
+
134
+
135
set_logging(LOGGING_NAME)  # run before defining LOGGER
LOGGER = logging.getLogger(LOGGING_NAME)  # define globally (used in train.py, val.py, detect.py, etc.)
if platform.system() == 'Windows':
    for fn in LOGGER.info, LOGGER.warning:
        # Bind fn at definition time via a default argument. The previous plain closure
        # (lambda x: fn(emojis(x))) late-bound fn, so after the loop BOTH .info and
        # .warning dispatched to LOGGER.warning.
        setattr(LOGGER, fn.__name__, lambda x, fn=fn: fn(emojis(x)))  # emoji-safe logging
140
+
141
+
142
def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'):
    """Return (creating if needed) the user configuration directory, preferring `env_var` when set."""
    env = os.getenv(env_var)
    if env:
        path = Path(env)  # explicit override via environment variable
    else:
        # per-OS conventional config locations
        os_dirs = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'}
        path = Path.home() / os_dirs.get(platform.system(), '')
        if not is_writeable(path):  # GCP and AWS lambda fix, only /tmp is writeable
            path = Path('/tmp')
        path = path / dir
    path.mkdir(exist_ok=True)  # make if required
    return path
153
+
154
+
155
CONFIG_DIR = user_config_dir()  # Ultralytics settings dir (e.g. downloaded fonts are cached here)
156
+
157
+
158
class Profile(contextlib.ContextDecorator):
    """Timing helper usable as @Profile() decorator or `with Profile():` context manager.

    The last measured interval is stored in `dt`; `t` accumulates across uses.
    """

    def __init__(self, t=0.0):
        self.t = t  # running total in seconds
        self.cuda = torch.cuda.is_available()  # whether CUDA sync is needed for accurate timing

    def __enter__(self):
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):
        elapsed = self.time() - self.start
        self.dt = elapsed  # delta-time of this interval
        self.t += elapsed  # accumulate

    def time(self):
        """Return the current wall-clock time, synchronizing CUDA first when available."""
        if self.cuda:
            torch.cuda.synchronize()
        return time.time()
176
+
177
+
178
class Timeout(contextlib.ContextDecorator):
    """Abort the wrapped block with TimeoutError after `seconds`.

    Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager.
    Implemented with SIGALRM, so it is a no-op on Windows. By default the
    TimeoutError is suppressed (the block simply stops early).
    """

    def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
        self.seconds = int(seconds)  # countdown length (whole seconds, SIGALRM granularity)
        self.timeout_message = timeout_msg  # message attached to the raised TimeoutError
        self.suppress = bool(suppress_timeout_errors)  # swallow the TimeoutError on exit?

    def _timeout_handler(self, signum, frame):
        # SIGALRM handler: convert the signal into an exception inside the block
        raise TimeoutError(self.timeout_message)

    def __enter__(self):
        if platform.system() != 'Windows':  # not supported on Windows
            signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM
            signal.alarm(self.seconds)  # start countdown for SIGALRM to be raised

    def __exit__(self, exc_type, exc_val, exc_tb):
        if platform.system() != 'Windows':
            signal.alarm(0)  # Cancel SIGALRM if it's scheduled
            if self.suppress and exc_type is TimeoutError:  # Suppress TimeoutError
                return True
198
+
199
+
200
class WorkingDirectory(contextlib.ContextDecorator):
    """Temporarily change the process working directory, restoring it on exit.

    Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager.
    """

    def __init__(self, new_dir):
        self.dir = new_dir  # new dir
        self.cwd = Path.cwd().resolve()  # current dir, restored in __exit__

    def __enter__(self):
        os.chdir(self.dir)

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.chdir(self.cwd)
211
+
212
+
213
def methods(instance):
    """Return the names of callable, non-dunder attributes of `instance`."""
    names = dir(instance)
    return [n for n in names if not n.startswith("__") and callable(getattr(instance, n))]
216
+
217
+
218
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    """Log the calling function's arguments (or an explicit `args` dict) via LOGGER."""
    x = inspect.currentframe().f_back  # previous frame (the caller whose args we report)
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically from the caller's local variables
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix('')
    except ValueError:  # caller file lives outside ROOT; fall back to bare stem
        file = Path(file).stem
    s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
    LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
231
+
232
+
233
def init_seeds(seed=0, deterministic=False):
    """Seed python/numpy/torch RNGs; optionally force fully deterministic torch algorithms.

    See https://pytorch.org/docs/stable/notes/randomness.html
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.cuda.manual_seed_all(seed)  # for Multi-GPU, exception safe
    # torch.backends.cudnn.benchmark = True  # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
    if deterministic and check_version(torch.__version__, '1.12.0'):  # https://github.com/ultralytics/yolov5/pull/8213
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        os.environ['PYTHONHASHSEED'] = str(seed)
246
+
247
+
248
def intersect_dicts(da, db, exclude=()):
    """Return items of `da` whose keys exist in `db` with identical shapes, skipping keys containing any `exclude` substring."""
    out = {}
    for k, v in da.items():
        if k in db and v.shape == db[k].shape and not any(token in k for token in exclude):
            out[k] = v  # keep da's value
    return out
251
+
252
+
253
def get_default_args(func):
    """Return a {name: default} dict for parameters of `func` that declare a default value."""
    params = inspect.signature(func).parameters
    return {name: p.default for name, p in params.items() if p.default is not inspect.Parameter.empty}
257
+
258
+
259
def get_latest_run(search_dir='.'):
    """Return the most recently created 'last*.pt' under `search_dir` ('' when none exist)."""
    candidates = glob.glob(f'{search_dir}/**/last*.pt', recursive=True)
    if not candidates:
        return ''
    return max(candidates, key=os.path.getctime)  # newest by creation time
263
+
264
+
265
def file_age(path=__file__):
    """Return whole days elapsed since `path` was last modified."""
    modified = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return (datetime.now() - modified).days
269
+
270
+
271
def file_date(path=__file__):
    """Return the human-readable modification date of `path`, e.g. '2021-3-26' (no zero-padding)."""
    t = datetime.fromtimestamp(Path(path).stat().st_mtime)
    return f'{t.year}-{t.month}-{t.day}'
275
+
276
+
277
def file_size(path):
    """Return the size of a file or directory tree in MiB (0.0 if `path` is neither)."""
    mib = 1 << 20  # bytes per MiB (1024 ** 2)
    p = Path(path)
    if p.is_file():
        return p.stat().st_size / mib
    if p.is_dir():
        return sum(f.stat().st_size for f in p.glob('**/*') if f.is_file()) / mib
    return 0.0
287
+
288
+
289
def check_online():
    """Return True when the internet appears reachable (probes 1.1.1.1:443, two attempts)."""
    import socket

    def probe():
        # single connectivity attempt with a 5 s timeout
        try:
            socket.create_connection(("1.1.1.1", 443), 5)
            return True
        except OSError:
            return False

    return probe() or probe()  # retry once: robust to intermittent connectivity blips
302
+
303
+
304
def git_describe(path=ROOT):  # path must be a directory
    """Return human-readable `git describe` output for the repo at `path`, or '' if unavailable.

    Example output: v5.0-5-g3e25f1e (https://git-scm.com/docs/git-describe)
    """
    try:
        assert (Path(path) / '.git').is_dir()  # only attempt inside a git checkout
        out = check_output(f'git -C {path} describe --tags --long --always', shell=True)
        return out.decode()[:-1]  # strip trailing newline
    except Exception:
        return ''
311
+
312
+
313
@TryExcept()
@WorkingDirectory(ROOT)
def check_git_status(repo='WongKinYiu/yolov9', branch='main'):
    """Warn via LOGGER when the local checkout is behind `repo`/`branch`, recommending 'git pull'."""
    url = f'https://github.com/{repo}'
    msg = f', for updates see {url}'
    s = colorstr('github: ')  # string
    # the asserts below are caught by @TryExcept, turning them into log messages
    assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg
    assert check_online(), s + 'skipping check (offline)' + msg

    # locate (or add) a remote pointing at `repo`
    splits = re.split(pattern=r'\s', string=check_output('git remote -v', shell=True).decode())
    matches = [repo in s for s in splits]
    if any(matches):
        remote = splits[matches.index(True) - 1]  # remote name precedes its URL in `git remote -v`
    else:
        remote = 'ultralytics'
        check_output(f'git remote add {remote} {url}', shell=True)
    check_output(f'git fetch {remote}', shell=True, timeout=5)  # git fetch
    local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked out
    n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True))  # commits behind
    if n > 0:
        pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
        s += f"⚠️ YOLO is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
    else:
        s += f'up to date with {url} ✅'
    LOGGER.info(s)
339
+
340
+
341
@WorkingDirectory(ROOT)
def check_git_info(path='.'):
    """Return {'remote', 'branch', 'commit'} git metadata for `path` (all None if not a git repo)."""
    check_requirements('gitpython')  # auto-installs gitpython when missing
    import git
    try:
        repo = git.Repo(path)
        remote = repo.remotes.origin.url.replace('.git', '')  # i.e. 'https://github.com/WongKinYiu/yolov9'
        commit = repo.head.commit.hexsha  # full 40-char commit sha
        try:
            branch = repo.active_branch.name  # i.e. 'main'
        except TypeError:  # not on any branch
            branch = None  # i.e. 'detached HEAD' state
        return {'remote': remote, 'branch': branch, 'commit': commit}
    except git.exc.InvalidGitRepositoryError:  # path is not a git dir
        return {'remote': None, 'branch': None, 'commit': None}
357
+
358
+
359
def check_python(minimum='3.7.0'):
    """Assert that the running Python interpreter satisfies `minimum` (raises via check_version hard=True)."""
    check_version(platform.python_version(), minimum, name='Python ', hard=True)
362
+
363
+
364
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
    """Compare `current` against `minimum` version strings.

    pinned=True requires exact equality; hard=True raises on failure; verbose=True logs a warning.
    Returns the boolean comparison result.
    """
    current, minimum = (pkg.parse_version(x) for x in (current, minimum))
    result = (current == minimum) if pinned else (current >= minimum)  # bool
    s = f'WARNING ⚠️ {name}{minimum} is required by YOLO, but {name}{current} is currently installed'  # string
    if hard:
        assert result, emojis(s)  # assert min requirements met
    if verbose and not result:
        LOGGER.warning(s)
    return result
374
+
375
+
376
@TryExcept()
def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=''):
    """Check installed dependencies meet YOLO requirements; optionally auto-install missing ones.

    `requirements` may be a Path to a requirements.txt, a single package string, or a list of strings.
    `exclude` names are skipped; `cmds` is appended to the pip command line.
    """
    prefix = colorstr('red', 'bold', 'requirements:')
    check_python()  # check python version
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f"{prefix} {file} not found, check failed."
        with file.open() as f:
            requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    s = ''  # accumulates quoted names of unmet requirements
    n = 0  # count of unmet requirements
    for r in requirements:
        try:
            pkg.require(r)
        except (pkg.VersionConflict, pkg.DistributionNotFound):  # exception if requirements not met
            s += f'"{r}" '
            n += 1

    if s and install and AUTOINSTALL:  # AUTOINSTALL is the YOLOv5_AUTOINSTALL env toggle
        LOGGER.info(f"{prefix} YOLO requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
        try:
            # assert check_online(), "AutoUpdate skipped (offline)"
            # NOTE(review): runs pip through the shell with names taken from the requirements
            # source — only safe with trusted requirement files
            LOGGER.info(check_output(f'pip install {s} {cmds}', shell=True).decode())
            source = file if 'file' in locals() else requirements
            s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \
                f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
            LOGGER.info(s)
        except Exception as e:
            LOGGER.warning(f'{prefix} ❌ {e}')
409
+
410
+
411
def check_img_size(imgsz, s=32, floor=0):
    """Round image size(s) up to a multiple of stride `s` (with `floor` minimum), warning when changed."""

    def _fit(x):
        # nearest valid size for one dimension
        return max(make_divisible(x, int(s)), floor)

    if isinstance(imgsz, int):  # integer i.e. img_size=640
        new_size = _fit(imgsz)
    else:  # list/tuple i.e. img_size=[640, 480]
        imgsz = list(imgsz)
        new_size = [_fit(x) for x in imgsz]
    if new_size != imgsz:
        LOGGER.warning(f'WARNING ⚠️ --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}')
    return new_size
421
+
422
+
423
def check_imshow(warn=False):
    """Return True when the environment supports cv2.imshow() windows (False in notebooks/docker)."""
    try:
        assert not is_notebook()  # GUI display unavailable in notebooks
        assert not is_docker()  # and headless containers
        # probe by actually opening and closing a tiny window
        cv2.imshow('test', np.zeros((1, 1, 3)))
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        if warn:
            LOGGER.warning(f'WARNING ⚠️ Environment does not support cv2.imshow() or PIL Image.show()\n{e}')
        return False
437
+
438
+
439
def check_suffix(file='yolo.pt', suffix=('.pt',), msg=''):
    """Assert each file in `file` (str or list/tuple) has one of the allowed `suffix` extensions."""
    if not (file and suffix):
        return
    if isinstance(suffix, str):
        suffix = [suffix]
    candidates = file if isinstance(file, (list, tuple)) else [file]
    for f in candidates:
        ext = Path(f).suffix.lower()  # file suffix
        if ext:  # extension-less names are not checked
            assert ext in suffix, f"{msg}{f} acceptable suffix is {suffix}"
448
+
449
+
450
def check_yaml(file, suffix=('.yaml', '.yml')):
    """Locate (or download) a YAML file and return its path, validating the suffix."""
    return check_file(file, suffix)
453
+
454
+
455
def check_file(file, suffix=''):
    """Return a local path for `file`, downloading URLs or searching project dirs as needed."""
    check_suffix(file, suffix)  # optional
    file = str(file)  # convert to str()
    if os.path.isfile(file) or not file:  # exists (or empty string passthrough)
        return file
    elif file.startswith(('http:/', 'https:/')):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = Path(urllib.parse.unquote(file).split('?')[0]).name  # '%2F' to '/', split https://url.com/file.txt?auth
        if os.path.isfile(file):
            LOGGER.info(f'Found {url} locally at {file}')  # file already exists
        else:
            LOGGER.info(f'Downloading {url} to {file}...')
            torch.hub.download_url_to_file(url, file)
            assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}'  # check
        return file
    elif file.startswith('clearml://'):  # ClearML Dataset ID
        assert 'clearml' in sys.modules, "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
        return file
    else:  # search within the project tree
        files = []
        for d in 'data', 'models', 'utils':  # search directories
            files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True))  # find file
        assert len(files), f'File not found: {file}'  # assert file was found
        assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}"  # assert unique
        return files[0]  # return file
481
+
482
+
483
def check_font(font=FONT, progress=False):
    """Download `font` into CONFIG_DIR when present neither at the given path nor in the cache."""
    font = Path(font)
    file = CONFIG_DIR / font.name  # cached copy location
    if font.exists() or file.exists():
        return  # already available
    url = f'https://ultralytics.com/assets/{font.name}'
    LOGGER.info(f'Downloading {url} to {file}...')
    torch.hub.download_url_to_file(url, str(file), progress=progress)
491
+
492
+
493
def check_dataset(data, autodownload=True):
    """Validate a dataset descriptor (yaml path, archive, or dict), downloading/unzipping when missing.

    Returns the dataset dict with 'path', 'train'/'val'/'test' resolved to absolute paths,
    'names' normalized to an int-keyed dict, and 'nc' set.
    """

    # Download (optional): a .zip/.tar argument is extracted under DATASETS_DIR first
    extract_dir = ''
    if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)):
        download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1)
        data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml'))  # first yaml inside the archive
        extract_dir, autodownload = data.parent, False

    # Read yaml (optional)
    if isinstance(data, (str, Path)):
        data = yaml_load(data)  # dictionary

    # Checks
    for k in 'train', 'val', 'names':
        assert k in data, emojis(f"data.yaml '{k}:' field missing ❌")
    if isinstance(data['names'], (list, tuple)):  # old array format
        data['names'] = dict(enumerate(data['names']))  # convert to dict
    assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car'
    data['nc'] = len(data['names'])

    # Resolve paths relative to the dataset root
    path = Path(extract_dir or data.get('path') or '')  # optional 'path' default to '.'
    if not path.is_absolute():
        path = (ROOT / path).resolve()
    data['path'] = path  # download scripts
    for k in 'train', 'val', 'test':
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith('../'):
                    x = (path / data[k][3:]).resolve()  # retry without the leading '../'
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse yaml: autodownload when val paths are missing and a 'download' entry exists
    train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()])
            if not s or not autodownload:
                raise Exception('Dataset not found ❌')
            t = time.time()
            if s.startswith('http') and s.endswith('.zip'):  # URL
                f = Path(s).name  # filename
                LOGGER.info(f'Downloading {s} to {f}...')
                torch.hub.download_url_to_file(s, f)
                Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True)  # create root
                unzip_file(f, path=DATASETS_DIR)  # unzip
                Path(f).unlink()  # remove zip
                r = None  # success
            elif s.startswith('bash '):  # bash script
                LOGGER.info(f'Running {s} ...')
                r = os.system(s)
            else:  # python script
                # NOTE(review): executes arbitrary python from the dataset yaml's 'download'
                # field — only use dataset descriptors from trusted sources
                r = exec(s, {'yaml': data})  # return None
            dt = f'({round(time.time() - t, 1)}s)'
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}")
    check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True)  # download fonts
    return data  # dictionary
557
+
558
+
559
def check_amp(model):
    """Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation."""
    from models.common import AutoShape, DetectMultiBackend

    def amp_allclose(model, im):
        # All close FP32 vs AMP results
        m = AutoShape(model, verbose=False)  # model
        a = m(im).xywhn[0]  # FP32 inference
        m.amp = True
        b = m(im).xywhn[0]  # AMP inference
        return a.shape == b.shape and torch.allclose(a, b, atol=0.1)  # close to 10% absolute tolerance

    prefix = colorstr('AMP: ')
    device = next(model.parameters()).device  # get model device
    if device.type in ('cpu', 'mps'):
        return False  # AMP only used on CUDA devices
    f = ROOT / 'data' / 'images' / 'bus.jpg'  # image to check
    im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
    try:
        # NOTE(review): the allclose comparison below is commented out, so this try body
        # cannot fail and CUDA devices always report AMP as working — confirm intentional
        #assert amp_allclose(deepcopy(model), im) or amp_allclose(DetectMultiBackend('yolo.pt', device), im)
        LOGGER.info(f'{prefix}checks passed ✅')
        return True
    except Exception:
        help_url = 'https://github.com/ultralytics/yolov5/issues/7908'
        LOGGER.warning(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')
        return False
585
+
586
+
587
def yaml_load(file='data.yaml'):
    """Safely load and return the contents of a YAML file (decode errors ignored)."""
    with open(file, errors='ignore') as handle:
        return yaml.safe_load(handle)
591
+
592
+
593
def yaml_save(file='data.yaml', data={}):
    """Safely dump `data` to a YAML file, converting Path values to str.

    The mutable default `data` is only iterated, never mutated, so sharing it is safe.
    """
    serializable = {k: str(v) if isinstance(v, Path) else v for k, v in data.items()}
    with open(file, 'w') as handle:
        yaml.safe_dump(serializable, handle, sort_keys=False)
597
+
598
+
599
def unzip_file(file, path=None, exclude=('.DS_Store', '__MACOSX')):
    """Extract a .zip archive into `path` (default: alongside the zip), skipping members whose name contains any `exclude` token."""
    if path is None:
        path = Path(file).parent  # default path
    with ZipFile(file) as archive:
        for member in archive.namelist():  # all archived filenames
            if not any(token in member for token in exclude):
                archive.extract(member, path=path)
607
+
608
+
609
def url2file(url):
    """Convert a URL to its bare filename, e.g. https://url.com/file.txt?auth -> file.txt."""
    url = str(Path(url)).replace(':/', '://')  # Pathlib collapses :// to :/
    decoded = urllib.parse.unquote(url)  # '%2F' to '/', etc.
    return Path(decoded).name.split('?')[0]  # drop ?auth query suffix
613
+
614
+
615
def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
    """Multithreaded file download and unzip function, used in data.yaml for autodownload.

    `url` may be a single URL/path, or an iterable of them when threads > 1.
    """

    def download_one(url, dir):
        # Download 1 file, then optionally unzip it
        success = True
        if os.path.isfile(url):
            f = Path(url)  # already a local file, nothing to download
        else:  # does not exist locally
            f = dir / Path(url).name
            LOGGER.info(f'Downloading {url} to {f}...')
            for i in range(retry + 1):
                if curl:
                    s = 'sS' if threads > 1 else ''  # silent when parallel
                    r = os.system(
                        f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -')  # curl download with retry, continue
                    success = r == 0
                else:
                    torch.hub.download_url_to_file(url, f, progress=threads == 1)  # torch download
                    success = f.is_file()
                if success:
                    break
                elif i < retry:
                    LOGGER.warning(f'⚠️ Download failure, retrying {i + 1}/{retry} {url}...')
                else:
                    LOGGER.warning(f'❌ Failed to download {url}...')

        if unzip and success and (f.suffix == '.gz' or is_zipfile(f) or is_tarfile(f)):
            LOGGER.info(f'Unzipping {f}...')
            if is_zipfile(f):
                unzip_file(f, dir)  # unzip
            elif is_tarfile(f):
                os.system(f'tar xf {f} --directory {f.parent}')  # unzip
            elif f.suffix == '.gz':
                os.system(f'tar xfz {f} --directory {f.parent}')  # unzip
            if delete:
                f.unlink()  # remove archive after extraction

    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    if threads > 1:
        # threaded fan-out; `url` must be an iterable of URLs in this branch
        pool = ThreadPool(threads)
        pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multithreaded
        pool.close()
        pool.join()
    else:
        for u in [url] if isinstance(url, (str, Path)) else url:
            download_one(u, dir)
662
+
663
+
664
def make_divisible(x, divisor):
    """Return the smallest multiple of `divisor` that is >= x (tensor divisors use their max)."""
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # collapse tensor strides to the largest
    return math.ceil(x / divisor) * divisor
669
+
670
+
671
def clean_str(s):
    """Return `s` with punctuation/shell special characters replaced by underscores."""
    return re.sub("[|@#!¡·$€%&()=?¿^*;:,¨´><+]", "_", s)
674
+
675
+
676
def one_cycle(y1=0.0, y2=1.0, steps=100):
    """Return a callable giving a sinusoidal ramp from y1 to y2 over `steps` (https://arxiv.org/pdf/1812.01187.pdf)."""

    def ramp(x):
        return ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

    return ramp
679
+
680
+
681
def one_flat_cycle(y1=0.0, y2=1.0, steps=100):
    """Return a callable that stays flat at y1 for the first half of `steps`, then cosine-ramps to y2."""

    def ramp(x):
        half = steps // 2  # flat portion ends here
        if x <= half:
            return y1
        return ((1 - math.cos((x - half) * math.pi / half)) / 2) * (y2 - y1) + y1

    return ramp
685
+
686
+
687
def colorstr(*inputs):
    """Wrap a string in ANSI escape codes, e.g. colorstr('blue', 'hello world').

    The last positional argument is the string; preceding ones are color/style names.
    With a single argument, defaults to blue + bold. https://en.wikipedia.org/wiki/ANSI_escape_code
    """
    *args, string = inputs if len(inputs) > 1 else ('blue', 'bold', inputs[0])
    colors = {
        'black': '\033[30m',  # basic colors
        'red': '\033[31m',
        'green': '\033[32m',
        'yellow': '\033[33m',
        'blue': '\033[34m',
        'magenta': '\033[35m',
        'cyan': '\033[36m',
        'white': '\033[37m',
        'bright_black': '\033[90m',  # bright colors
        'bright_red': '\033[91m',
        'bright_green': '\033[92m',
        'bright_yellow': '\033[93m',
        'bright_blue': '\033[94m',
        'bright_magenta': '\033[95m',
        'bright_cyan': '\033[96m',
        'bright_white': '\033[97m',
        'end': '\033[0m',  # misc
        'bold': '\033[1m',
        'underline': '\033[4m'}
    prefix = ''.join(colors[x] for x in args)
    return prefix + f'{string}' + colors['end']
711
+
712
+
713
def labels_to_class_weights(labels, nc=80):
    """Return normalized inverse-frequency class weights (length nc) from a list of [class, xywh] label arrays."""
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    stacked = np.concatenate(labels, 0)  # e.g. shape (866643, 5) for COCO
    classes = stacked[:, 0].astype(int)  # first column is the class id
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    weights[weights == 0] = 1  # absent classes count as 1 to avoid div-by-zero
    weights = 1 / weights  # inverse frequency
    weights /= weights.sum()  # normalize to sum 1
    return torch.from_numpy(weights).float()
730
+
731
+
732
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    """Weight each image by the class_weights of the classes it contains.

    Usage: index = random.choices(range(n), weights=image_weights, k=1)  # weighted image sample
    """
    per_image_counts = np.array([np.bincount(lbl[:, 0].astype(int), minlength=nc) for lbl in labels])
    return (per_image_counts * class_weights.reshape(1, nc)).sum(1)
737
+
738
+
739
def coco80_to_coco91_class():  # converts 80-index (val2014) to 91-index (paper)
    """Return the mapping from 80-class COCO indices to the 91-class paper indices.

    The COCO paper defines 91 categories but only 80 appear in the released data;
    element i of the returned list is the 91-class id of 80-class id i.
    https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
    """
    # How this table was derived (kept for reference):
    # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')
    # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')
    # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]  # darknet to coco
    # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]  # coco to darknet
    return [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,
        35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,
        64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
749
+
750
+
751
def xyxy2xywh(x):
    """Convert nx4 boxes from corner form [x1, y1, x2, y2] to center form [x, y, w, h]."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[..., 0] = (x[..., 0] + x[..., 2]) / 2  # center x
    out[..., 1] = (x[..., 1] + x[..., 3]) / 2  # center y
    out[..., 2] = x[..., 2] - x[..., 0]  # width
    out[..., 3] = x[..., 3] - x[..., 1]  # height
    return out
759
+
760
+
761
def xywh2xyxy(x):
    """Convert nx4 boxes from center form [x, y, w, h] to corner form [x1, y1, x2, y2]."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[..., 0] = x[..., 0] - x[..., 2] / 2  # top-left x
    out[..., 1] = x[..., 1] - x[..., 3] / 2  # top-left y
    out[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom-right x
    out[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom-right y
    return out
769
+
770
+
771
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized center-form [x, y, w, h] boxes to pixel corner form
    [x1, y1, x2, y2], scaling by image size (w, h) and shifting by padding."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw  # top-left x
    out[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh  # top-left y
    out[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw  # bottom-right x
    out[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh  # bottom-right y
    return out
779
+
780
+
781
def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0):
    """Convert pixel corner-form [x1, y1, x2, y2] boxes to normalized center form
    [x, y, w, h]. With clip=True the INPUT is first clipped in place to the image."""
    if clip:
        clip_boxes(x, (h - eps, w - eps))  # warning: inplace clip
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w  # normalized center x
    out[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h  # normalized center y
    out[..., 2] = (x[..., 2] - x[..., 0]) / w  # normalized width
    out[..., 3] = (x[..., 3] - x[..., 1]) / h  # normalized height
    return out
791
+
792
+
793
def xyn2xy(x, w=640, h=640, padw=0, padh=0):
    """Convert normalized (n,2) segment points into pixel coordinates."""
    out = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    out[..., 0] = w * x[..., 0] + padw  # pixel x
    out[..., 1] = h * x[..., 1] + padh  # pixel y
    return out
799
+
800
+
801
def segment2box(segment, width=640, height=640):
    """Convert one (n,2) polygon segment to a single xyxy box, keeping only points
    inside the image, i.e. (xy1, xy2, ...) -> (xyxy).

    Returns np.zeros((1, 4)) when no point lies inside the image.
    """
    x, y = segment.T  # segment xy
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x, y = x[inside], y[inside]
    # Fix: test len(x), not any(x) — any(x) is False when all surviving x-coords are 0,
    # wrongly discarding valid points lying on the left image edge.
    return np.array([x.min(), y.min(), x.max(), y.max()]) if len(x) else np.zeros((1, 4))  # xyxy
807
+
808
+
809
def segments2boxes(segments):
    """Convert polygon segment labels to xywh box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)."""
    extents = [[s[:, 0].min(), s[:, 1].min(), s[:, 0].max(), s[:, 1].max()] for s in segments]  # per-segment xyxy
    return xyxy2xywh(np.array(extents))  # cls, xywh
816
+
817
+
818
def resample_segments(segments, n=1000):
    """Up-sample each (m,2) segment to n points by linear interpolation, closing the polygon.

    Modifies `segments` in place and returns it.
    """
    for idx, seg in enumerate(segments):
        closed = np.concatenate((seg, seg[0:1, :]), axis=0)  # append first point to close the loop
        xp = np.arange(len(closed))  # original sample positions
        xq = np.linspace(0, len(closed) - 1, n)  # query positions
        cols = [np.interp(xq, xp, closed[:, k]) for k in range(2)]  # interpolate x and y separately
        segments[idx] = np.concatenate(cols).reshape(2, -1).T  # back to (n, 2)
    return segments
826
+
827
+
828
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None):
    """Rescale xyxy boxes in place from the letterboxed img1_shape back to img0_shape.

    ratio_pad, if given, is ((gain, ...), (pad_w, pad_h)); otherwise gain and padding
    are derived from the two shapes.
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain, pad = ratio_pad[0][0], ratio_pad[1]

    boxes[:, [0, 2]] -= pad[0]  # remove x padding
    boxes[:, [1, 3]] -= pad[1]  # remove y padding
    boxes[:, :4] /= gain  # undo letterbox scaling
    clip_boxes(boxes, img0_shape)  # keep boxes inside the original image
    return boxes
842
+
843
+
844
def scale_segments(img1_shape, segments, img0_shape, ratio_pad=None, normalize=False):
    """Rescale (n,2) segment coords in place from letterboxed img1_shape to img0_shape.

    With normalize=True the result is additionally divided by img0 width/height.
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
    else:
        gain, pad = ratio_pad[0][0], ratio_pad[1]

    segments[:, 0] -= pad[0]  # remove x padding
    segments[:, 1] -= pad[1]  # remove y padding
    segments /= gain  # undo letterbox scaling
    clip_segments(segments, img0_shape)
    if normalize:
        segments[:, 0] /= img0_shape[1]  # x by width
        segments[:, 1] /= img0_shape[0]  # y by height
    return segments
861
+
862
+
863
def clip_boxes(boxes, shape):
    """Clip xyxy boxes in place to the image bounds shape = (height, width)."""
    h, w = shape[0], shape[1]
    if isinstance(boxes, torch.Tensor):  # clamp columns individually (faster for tensors)
        boxes[:, 0].clamp_(0, w)  # x1
        boxes[:, 1].clamp_(0, h)  # y1
        boxes[:, 2].clamp_(0, w)  # x2
        boxes[:, 3].clamp_(0, h)  # y2
    else:  # np.array: clip x and y column pairs together (faster grouped)
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, w)  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, h)  # y1, y2
873
+
874
+
875
def clip_segments(segments, shape):
    """Clip (n,2) segment points in place to the image bounds shape = (height, width)."""
    h, w = shape[0], shape[1]
    if isinstance(segments, torch.Tensor):  # clamp columns individually (faster for tensors)
        segments[:, 0].clamp_(0, w)  # x
        segments[:, 1].clamp_(0, h)  # y
    else:  # np.array (faster grouped)
        segments[:, 0] = segments[:, 0].clip(0, w)  # x
        segments[:, 1] = segments[:, 1].clip(0, h)  # y
883
+
884
+
885
def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nm=0,  # number of masks
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections.

    Args:
        prediction: raw model output; the split below implies layout
            (bs, 4 + nc + nm, n_anchors) with xywh boxes in rows 0:4, class scores
            in 4:4+nc and mask coefficients after that — confirm against the model head.
        conf_thres (float): confidence threshold in [0, 1].
        iou_thres (float): IoU threshold in [0, 1] used by torchvision NMS.
        classes: optional iterable of class indices to keep; others are dropped.
        agnostic (bool): if True, suppress across classes (class-agnostic NMS).
        multi_label (bool): allow multiple class labels per box (adds ~0.5ms/img).
        labels: optional per-image apriori labels appended for autolabelling.
        max_det (int): maximum detections retained per image.
        nm (int): number of mask coefficients per prediction.

    Returns:
        list of detections, one (n, 6 + nm) tensor per image [xyxy, conf, cls, masks...]
    """

    if isinstance(prediction, (list, tuple)):  # YOLO model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = prediction.shape[1] - nm - 4  # number of classes
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates: best class score above threshold

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    max_wh = 7680  # (pixels) maximum box width and height; also the class offset for batched NMS
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 2.5 + 0.05 * bs  # seconds to quit after
    redundant = True  # require redundant detections (merge-NMS only)
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs  # default: empty result per image
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x.T[xc[xi]]  # transpose (no, n) -> (n, no), then keep confident candidates

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            # NOTE(review): v is allocated with nc + nm + 5 columns while x has 4 + nc + nm —
            # widths differ by 1; confirm the autolabelling path before relying on it.
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls one-hot with full confidence
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)  # center_x, center_y, width, height) to (x1, y1, x2, y2)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T  # one row per (box, class) over threshold
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
        else:
            x = x[x[:, 4].argsort(descending=True)]  # sort by confidence

        # Batched NMS: offset boxes by class * max_wh so different classes never overlap
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)  # move results back to the original device
        if (time.time() - t) > time_limit:
            LOGGER.warning(f'WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded')
            break  # time limit exceeded

    return output
995
+
996
+
997
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    # NOTE: torch.load unpickles arbitrary objects — only call on trusted checkpoints.
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with its EMA weights
    for k in 'optimizer', 'best_fitness', 'ema', 'updates':  # training-only keys, not needed for inference
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16 to halve the file size
    for p in x['model'].parameters():
        p.requires_grad = False  # inference only
    torch.save(x, s or f)  # overwrite f unless an alternate destination s is given
    mb = os.path.getsize(s or f) / 1E6  # filesize
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
1011
+
1012
+
1013
def print_mutation(keys, results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')):
    # Log one hyperparameter-evolution generation: append to evolve.csv, rewrite
    # hyp_evolve.yaml with the best generation so far, and optionally sync via GCS.
    evolve_csv = save_dir / 'evolve.csv'
    evolve_yaml = save_dir / 'hyp_evolve.yaml'
    keys = tuple(keys) + tuple(hyp.keys())  # [results + hyps]
    keys = tuple(x.strip() for x in keys)
    vals = results + tuple(hyp.values())
    n = len(keys)

    # Download (optional)
    if bucket:
        url = f'gs://{bucket}/evolve.csv'
        if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0):
            os.system(f'gsutil cp {url} {save_dir}')  # download evolve.csv if larger than local
    # Log to evolve.csv
    s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n')  # add header on first write
    with open(evolve_csv, 'a') as f:
        f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n')

    # Save yaml
    with open(evolve_yaml, 'w') as f:
        data = pd.read_csv(evolve_csv)
        data = data.rename(columns=lambda x: x.strip())  # strip keys
        i = np.argmax(fitness(data.values[:, :4]))  # index of best generation by fitness of first 4 metrics
        generations = len(data)
        f.write('# YOLO Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' +
                f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) +
                '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n')
        yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False)  # hyp columns start at index 7

    # Print to screen
    LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix +
                ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}'
                                                                                        for x in vals) + '\n\n')

    if bucket:
        os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}')  # upload
1050
+
1051
+
1052
def apply_classifier(x, model, img, im0):
    # Apply a second stage classifier to YOLO outputs: keep only detections whose
    # crop, re-classified by `model`, agrees with the detector's predicted class.
    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    im0 = [im0] if isinstance(im0, np.ndarray) else im0
    for i, d in enumerate(x):  # per image
        if d is not None and len(d):
            d = d.clone()

            # Reshape and pad cutouts so each crop is square with some margin
            b = xyxy2xywh(d[:, :4])  # boxes
            b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to square
            b[:, 2:] = b[:, 2:] * 1.3 + 30  # pad
            d[:, :4] = xywh2xyxy(b).long()

            # Rescale boxes from img_size to im0 size
            scale_boxes(img.shape[2:], d[:, :4], im0[i].shape)

            # Classes
            pred_cls1 = d[:, 5].long()  # detector class per box
            ims = []
            for a in d:
                cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]  # crop box region
                im = cv2.resize(cutout, (224, 224))  # BGR, classifier input size

                im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, HWC to CHW
                im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                ims.append(im)

            pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction
            x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detections

    return x
1085
+
1086
+
1087
def increment_path(path, exist_ok=False, sep='', mkdir=False):
    """Increment a file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ...

    If the path does not exist (or exist_ok=True) it is returned unchanged.
    With mkdir=True the resulting directory is created.
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        # Separate the suffix for files so the counter goes before the extension
        path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '')

        # Probe increasing counters until a free path is found
        for n in range(2, 9999):
            candidate = f'{path}{sep}{n}{suffix}'
            if not os.path.exists(candidate):
                break
        path = Path(candidate)

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path
1111
+
1112
+
1113
# OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------
imshow_ = cv2.imshow  # keep a reference to the original cv2.imshow so the patched imshow below does not recurse
1115
+
1116
+
1117
def imread(path, flags=cv2.IMREAD_COLOR):
    """Unicode-path-safe cv2.imread replacement: read raw bytes with numpy, then decode."""
    raw = np.fromfile(path, np.uint8)
    return cv2.imdecode(raw, flags)
1119
+
1120
+
1121
def imwrite(path, im):
    """Unicode-path-safe cv2.imwrite replacement: encode in memory, write with tofile.

    Returns True on success, False on any failure.
    """
    try:
        encoded = cv2.imencode(Path(path).suffix, im)[1]
        encoded.tofile(path)
        return True
    except Exception:
        return False
1127
+
1128
+
1129
def imshow(path, im):
    """Unicode-safe cv2.imshow replacement: escape non-ASCII window titles first."""
    title = path.encode('unicode_escape').decode()
    imshow_(title, im)
1131
+
1132
+
1133
cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow  # monkey-patch cv2 with Unicode-path-safe versions
1134
+
1135
+ # Variables ------------------------------------------------------------------------------------------------------------
plots.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import math
3
+ import os
4
+ from copy import copy
5
+ from pathlib import Path
6
+ from urllib.error import URLError
7
+
8
+ import cv2
9
+ import matplotlib
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import pandas as pd
13
+ import seaborn as sn
14
+ import torch
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+ from utils import TryExcept, threaded
18
+ from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
19
+ is_ascii, xywh2xyxy, xyxy2xywh)
20
+ from utils.metrics import fitness
21
+ from utils.segment.general import scale_image
22
+
23
# Settings
RANK = int(os.getenv('RANK', -1))  # distributed (DDP) process rank; -1 when not distributed
matplotlib.rc('font', **{'size': 11})
matplotlib.use('Agg')  # non-interactive backend: figures are written to files only
27
+
28
+
29
class Colors:
    """Ultralytics color palette https://ultralytics.com/ — cycling RGB lookup by index."""

    def __init__(self):
        # hex = matplotlib.colors.TABLEAU_COLORS.values()
        hexs = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334', '00D4BB',
                '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF', 'FF95C8', 'FF37C7')
        self.palette = [self.hex2rgb(f'#{c}') for c in hexs]
        self.n = len(self.palette)

    def __call__(self, i, bgr=False):
        """Return palette color i (modulo palette size) as an RGB tuple, or BGR if bgr=True."""
        r, g, b = self.palette[int(i) % self.n]
        return (b, g, r) if bgr else (r, g, b)

    @staticmethod
    def hex2rgb(h):
        """Convert '#RRGGBB' to an (r, g, b) integer tuple (PIL order)."""
        return tuple(int(h[1 + k:1 + k + 2], 16) for k in (0, 2, 4))
45
+
46
+
47
colors = Colors()  # shared module-level instance for 'from utils.plots import colors'
48
+
49
+
50
def check_pil_font(font=FONT, size=10):
    """Return a PIL TrueType font, downloading it to CONFIG_DIR if necessary."""
    font = Path(font)
    font = font if font.exists() else (CONFIG_DIR / font.name)  # fall back to the config dir copy
    try:
        return ImageFont.truetype(str(font) if font.exists() else font.name, size)
    except Exception:  # download if missing
        try:
            check_font(font)  # attempt to download the font file
            return ImageFont.truetype(str(font), size)
        except TypeError:
            # known issue https://github.com/ultralytics/yolov5/issues/5374
            # NOTE(review): this branch falls through and implicitly returns None —
            # confirm callers tolerate a missing font here.
            check_requirements('Pillow>=8.4.0')
        except URLError:  # not online
            return ImageFont.load_default()
64
+
65
+
66
class Annotator:
    """YOLO Annotator for train/val mosaics and jpgs and detect/hub inference annotations.

    Draws boxes, labels, masks and text on an image using either PIL (required for
    non-ASCII labels) or cv2, selected at construction time.
    """

    def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
        assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
        non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
        self.pil = pil or non_ascii
        # Fix: compute the line width for BOTH backends. Previously self.lw was only
        # assigned in the cv2 branch, so box_label() raised AttributeError in PIL mode.
        self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width
        if self.pil:  # use PIL
            self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
            self.draw = ImageDraw.Draw(self.im)
            self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
                                       size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
        else:  # use cv2
            self.im = im

    def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
        """Add one xyxy box to the image, with an optional filled label above/inside it."""
        if self.pil or not is_ascii(label):
            self.draw.rectangle(box, width=self.lw, outline=color)  # box
            if label:
                w, h = self.font.getsize(label)  # text width, height
                outside = box[1] - h >= 0  # label fits outside box
                self.draw.rectangle(
                    (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
                     box[1] + 1 if outside else box[1] + h + 1),
                    fill=color,
                )
                # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
        else:  # cv2
            p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
            cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
            if label:
                tf = max(self.lw - 1, 1)  # font thickness
                w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]  # text width, height
                outside = p1[1] - h >= 3  # label fits above box
                p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled label background
                cv2.putText(self.im,
                            label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                            0,
                            self.lw / 3,
                            txt_color,
                            thickness=tf,
                            lineType=cv2.LINE_AA)

    def masks(self, masks, colors, im_gpu=None, alpha=0.5):
        """Plot masks at once.
        Args:
            masks (tensor): predicted masks on cuda, shape: [n, h, w]
            colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
            im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
            alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
        """
        if self.pil:
            # convert to numpy first so the blend can be done in-place
            self.im = np.asarray(self.im).copy()
        if im_gpu is None:
            # Add multiple masks of shape(h,w,n) with colors list([r,g,b], [r,g,b], ...)
            if len(masks) == 0:
                return
            if isinstance(masks, torch.Tensor):
                masks = torch.as_tensor(masks, dtype=torch.uint8)
                masks = masks.permute(1, 2, 0).contiguous()
                masks = masks.cpu().numpy()
            # masks = np.ascontiguousarray(masks.transpose(1, 2, 0))
            masks = scale_image(masks.shape[:2], masks, self.im.shape)
            masks = np.asarray(masks, dtype=np.float32)
            colors = np.asarray(colors, dtype=np.float32)  # shape(n,3)
            s = masks.sum(2, keepdims=True).clip(0, 1)  # union of all masks, used as alpha mask
            masks = (masks @ colors).clip(0, 255)  # (h,w,n) @ (n,3) = (h,w,3)
            self.im[:] = masks * alpha + self.im * (1 - s * alpha)
        else:
            if len(masks) == 0:
                self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
            colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
            colors = colors[:, None, None]  # shape(n,1,1,3)
            masks = masks.unsqueeze(3)  # shape(n,h,w,1)
            masks_color = masks * (colors * alpha)  # shape(n,h,w,3)

            inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
            mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)

            im_gpu = im_gpu.flip(dims=[0])  # flip channel order (BGR<->RGB)
            im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
            im_gpu = im_gpu * inv_alph_masks[-1] + mcs
            im_mask = (im_gpu * 255).byte().cpu().numpy()
            self.im[:] = scale_image(im_gpu.shape, im_mask, self.im.shape)
        if self.pil:
            # convert im back to PIL and update draw
            self.fromarray(self.im)

    def rectangle(self, xy, fill=None, outline=None, width=1):
        """Add a rectangle to the image (PIL-only)."""
        self.draw.rectangle(xy, fill, outline, width)

    def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
        """Add text to the image (PIL-only); anchor='bottom' starts y from the font bottom."""
        if anchor == 'bottom':  # start y from font bottom
            w, h = self.font.getsize(text)  # text width, height
            xy[1] += 1 - h
        self.draw.text(xy, text, fill=txt_color, font=self.font)

    def fromarray(self, im):
        """Update self.im (and the PIL draw handle) from a numpy array."""
        self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
        self.draw = ImageDraw.Draw(self.im)

    def result(self):
        """Return the annotated image as a numpy array."""
        return np.asarray(self.im)
177
+
178
+
179
def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
    """Save a grid plot (and .npy dump) of intermediate feature maps for one module.

    x: Features to be visualized
    module_type: Module type
    stage: Module stage within model
    n: Maximum number of feature maps to plot
    save_dir: Directory to save results
    """
    if 'Detect' not in module_type:  # skip detection heads
        batch, channels, height, width = x.shape  # batch, channels, height, width
        if height > 1 and width > 1:  # skip 1x1 maps (nothing to show)
            f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename

            blocks = torch.chunk(x[0].cpu(), channels, dim=0)  # select batch index 0, block by channels
            n = min(n, channels)  # number of plots
            fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # 8 rows x n/8 cols
            ax = ax.ravel()
            plt.subplots_adjust(wspace=0.05, hspace=0.05)
            for i in range(n):
                ax[i].imshow(blocks[i].squeeze())  # cmap='gray'
                ax[i].axis('off')

            LOGGER.info(f'Saving {f}... ({n}/{channels})')
            plt.savefig(f, dpi=300, bbox_inches='tight')
            plt.close()
            np.save(str(f.with_suffix('.npy')), x[0].cpu().numpy())  # npy save for offline inspection
205
+
206
+
207
def hist2d(x, y, n=100):
    """Return the log 2-D histogram density at each (x, y) point; used in labels.png and evolve.png."""
    edges_x = np.linspace(x.min(), x.max(), n)
    edges_y = np.linspace(y.min(), y.max(), n)
    hist, edges_x, edges_y = np.histogram2d(x, y, (edges_x, edges_y))
    ix = np.clip(np.digitize(x, edges_x) - 1, 0, hist.shape[0] - 1)  # bin index of each x
    iy = np.clip(np.digitize(y, edges_y) - 1, 0, hist.shape[1] - 1)  # bin index of each y
    return np.log(hist[ix, iy])  # log bin count per input point
214
+
215
+
216
def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
    """Zero-phase Butterworth low-pass filter of `data`.

    https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
    """
    from scipy.signal import butter, filtfilt

    nyquist = 0.5 * fs
    b, a = butter(order, cutoff / nyquist, btype='low', analog=False)
    return filtfilt(b, a, data)  # forward-backward filter: no phase shift
227
+
228
+
229
def output_to_target(output, max_det=300):
    """Convert per-image model detections into plotting rows [batch_id, class_id, x, y, w, h, conf]."""
    rows = []
    for img_idx, det in enumerate(output):
        box, conf, cls = det[:max_det, :6].cpu().split((4, 1, 1), 1)  # cap at max_det detections
        batch_col = torch.full((conf.shape[0], 1), img_idx)  # image index column
        rows.append(torch.cat((batch_col, cls, xyxy2xywh(box), conf), 1))
    return torch.cat(rows, 0).numpy()
237
+
238
+
239
@threaded
def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
    # Plot an image grid (mosaic) with box labels and save it to fname (runs in a thread).
    # targets rows are [batch_id, class, x, y, w, h(, conf)]; boxes may be normalized or pixel.
    if isinstance(images, torch.Tensor):
        images = images.cpu().float().numpy()
    if isinstance(targets, torch.Tensor):
        targets = targets.cpu().numpy()

    max_size = 1920  # max image size
    max_subplots = 16  # max image subplots, i.e. 4x4
    bs, _, h, w = images.shape  # batch size, _, height, width
    bs = min(bs, max_subplots)  # limit plot images
    ns = np.ceil(bs ** 0.5)  # number of subplots (square)
    if np.max(images[0]) <= 1:
        images *= 255  # de-normalise (optional)

    # Build Image
    mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init white canvas
    for i, im in enumerate(images):
        if i == max_subplots:  # if last batch has fewer images than we expect
            break
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        im = im.transpose(1, 2, 0)  # CHW -> HWC
        mosaic[y:y + h, x:x + w, :] = im

    # Resize (optional)
    scale = max_size / ns / max(h, w)
    if scale < 1:
        h = math.ceil(scale * h)
        w = math.ceil(scale * w)
        mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

    # Annotate
    fs = int((h + w) * ns * 0.01)  # font size
    annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
    for i in range(i + 1):  # iterate over the tiles actually filled above
        x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
        annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
        if paths:
            annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
        if len(targets) > 0:
            ti = targets[targets[:, 0] == i]  # targets belonging to this image
            boxes = xywh2xyxy(ti[:, 2:6]).T  # transposed: boxes is (4, n)
            classes = ti[:, 1].astype('int')
            labels = ti.shape[1] == 6  # ground-truth labels if no conf column
            conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)

            if boxes.shape[1]:
                if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                    boxes[[0, 2]] *= w  # scale to pixels
                    boxes[[1, 3]] *= h
                elif scale < 1:  # absolute coords need scale if image scales
                    boxes *= scale
                boxes[[0, 2]] += x  # shift into this tile
                boxes[[1, 3]] += y
                for j, box in enumerate(boxes.T.tolist()):
                    cls = classes[j]
                    color = colors(cls)
                    cls = names[cls] if names else cls
                    if labels or conf[j] > 0.25:  # 0.25 conf thresh for predictions
                        label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
                        annotator.box_label(box, label, color=color)
    annotator.im.save(fname)  # save
302
+
303
+
304
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
    # Simulate a full training run of `epochs` steps and plot the LR curve to save_dir/LR.png
    optimizer, scheduler = copy(optimizer), copy(scheduler)  # do not modify originals
    y = []
    for _ in range(epochs):
        scheduler.step()
        y.append(optimizer.param_groups[0]['lr'])  # record LR of the first param group
    plt.plot(y, '.-', label='LR')
    plt.xlabel('epoch')
    plt.ylabel('LR')
    plt.grid()
    plt.xlim(0, epochs)
    plt.ylim(0)
    plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
    plt.close()
319
+
320
+
321
def plot_val_txt():  # from utils.plots import *; plot_val()
    """Plot 2D and 1D histograms of box centers loaded from val.txt."""
    data = np.loadtxt('val.txt', dtype=np.float32)
    centers = xyxy2xywh(data[:, :4])  # xyxy -> xywh; columns 0/1 are box centers
    cx, cy = centers[:, 0], centers[:, 1]

    # 2D density of centers
    fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
    ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
    ax.set_aspect('equal')
    plt.savefig('hist2d.png', dpi=300)

    # 1D marginals of cx and cy
    fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
    ax[0].hist(cx, bins=600)
    ax[1].hist(cy, bins=600)
    plt.savefig('hist1d.png', dpi=200)
336
+
337
+
338
def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
    """Plot per-component (x, y, width, height) histograms of targets.txt into targets.jpg."""
    data = np.loadtxt('targets.txt', dtype=np.float32).T
    titles = ['x targets', 'y targets', 'width targets', 'height targets']
    fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
    ax = ax.ravel()
    for i in range(4):
        series = data[i]
        ax[i].hist(series, bins=100, label=f'{series.mean():.3g} +/- {series.std():.3g}')
        ax[i].legend()
        ax[i].set_title(titles[i])
    plt.savefig('targets.jpg', dpi=200)
349
+
350
+
351
def plot_val_study(file='', dir='', x=None):  # from utils.plots import *; plot_val_study()
    """Plot file=study.txt generated by val.py (or all study*.txt in dir) as a speed/accuracy curve.

    Saves study.png in the same directory, with an EfficientDet reference curve for comparison.
    """
    save_dir = Path(file).parent if file else Path(dir)
    plot2 = False  # plot additional per-metric results (disabled by default)
    if plot2:
        ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()

    fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
    # for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
    for f in sorted(save_dir.glob('study*.txt')):
        # columns: P, R, mAP@.5, mAP@.5:.95, preprocess, inference, NMS times (per row below)
        y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
        x = np.arange(y.shape[1]) if x is None else np.array(x)
        if plot2:
            s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
            for i in range(7):
                ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
                ax[i].set_title(s[i])

        j = y[3].argmax() + 1  # truncate the curve just past peak mAP@.5:.95
        ax2.plot(y[5, 1:j],
                 y[3, 1:j] * 1E2,
                 '.-',
                 linewidth=2,
                 markersize=8,
                 label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

    # EfficientDet reference: speeds converted from published values, AP points hard-coded
    ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
             'k.-',
             linewidth=2,
             markersize=8,
             alpha=.25,
             label='EfficientDet')

    ax2.grid(alpha=0.2)
    ax2.set_yticks(np.arange(20, 60, 5))
    ax2.set_xlim(0, 57)
    ax2.set_ylim(25, 55)
    ax2.set_xlabel('GPU Speed (ms/img)')
    ax2.set_ylabel('COCO AP val')
    ax2.legend(loc='lower right')
    f = save_dir / 'study.png'
    print(f'Saving {f}...')
    plt.savefig(f, dpi=300)
394
+
395
+
396
@TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
def plot_labels(labels, names=(), save_dir=Path('')):
    """Plot dataset label statistics to save_dir (labels.jpg and labels_correlogram.jpg).

    Args:
        labels: (n, 5) array of [class, x, y, w, h] rows (xywh presumably normalized 0-1 — TODO confirm).
        names: mapping of class index -> class name (dict-like; .values() is used), or empty.
        save_dir: output directory.

    NOTE: mutates `labels` in place when drawing the rectangle panel.
    """
    LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
    c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
    nc = int(c.max() + 1)  # number of classes
    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

    # seaborn correlogram
    sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
    plt.close()

    # matplotlib labels
    matplotlib.use('svg')  # faster
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
    y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    with contextlib.suppress(Exception):  # color histogram bars by class
        [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # known issue #3195
    ax[0].set_ylabel('instances')
    if 0 < len(names) < 30:  # only label ticks when the class list is small enough to read
        ax[0].set_xticks(range(len(names)))
        ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
    else:
        ax[0].set_xlabel('classes')
    sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
    sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

    # rectangles: draw up to 1000 boxes, all re-centered, on a white 2000x2000 canvas
    labels[:, 1:3] = 0.5  # center
    labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
    img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
    for cls, *box in labels[:1000]:
        ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
    ax[1].imshow(img)
    ax[1].axis('off')

    # hide all axis spines for a cleaner figure
    for a in [0, 1, 2, 3]:
        for s in ['top', 'right', 'left', 'bottom']:
            ax[a].spines[s].set_visible(False)

    plt.savefig(save_dir / 'labels.jpg', dpi=200)
    matplotlib.use('Agg')  # restore the default non-interactive backend
    plt.close()
440
+
441
+
442
def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
    """Show a classification image grid with labels (optional) and predictions (optional).

    Args:
        im: BCHW batch tensor of normalized images (denormalized before display).
        labels: optional true class indices, one per image.
        pred: optional predicted class indices, one per image.
        names: optional class names; defaults to generic 'class{i}'.
        nmax: maximum number of images shown.
        verbose: log the saved filename plus true/predicted class names.
        f: output image path; returned to the caller.
    """
    from utils.augmentations import denormalize

    names = names or [f'class{i}' for i in range(1000)]
    blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
                         dim=0)  # select batch index 0, block by channels
    n = min(len(blocks), nmax)  # number of plots
    m = min(8, round(n ** 0.5))  # 8 x 8 default
    fig, ax = plt.subplots(math.ceil(n / m), m)  # 8 rows x n/8 cols
    ax = ax.ravel() if m > 1 else [ax]  # subplots() returns a bare Axes when the grid is 1x1
    # plt.subplots_adjust(wspace=0.05, hspace=0.05)
    for i in range(n):
        ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
        ax[i].axis('off')
        if labels is not None:
            s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
            ax[i].set_title(s, fontsize=8, verticalalignment='top')
    plt.savefig(f, dpi=300, bbox_inches='tight')
    plt.close()
    if verbose:
        LOGGER.info(f"Saving {f}")
        if labels is not None:
            LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
        if pred is not None:
            LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
    return f
469
+
470
+
471
def plot_evolve(evolve_csv='path/to/evolve.csv'):  # from utils.plots import *; plot_evolve()
    """Plot hyperparameter-evolution results from evolve.csv as per-hyp fitness scatter plots.

    Saves a .png next to the CSV and prints the best row's hyperparameter values.
    """
    evolve_csv = Path(evolve_csv)
    data = pd.read_csv(evolve_csv)
    keys = [x.strip() for x in data.columns]
    x = data.values
    f = fitness(x)
    j = np.argmax(f)  # max fitness index
    plt.figure(figsize=(10, 12), tight_layout=True)
    matplotlib.rc('font', **{'size': 8})
    print(f'Best results from row {j} of {evolve_csv}:')
    for i, k in enumerate(keys[7:]):  # first 7 columns are metrics; the rest are hyperparameters
        v = x[:, 7 + i]
        mu = v[j]  # best single result
        plt.subplot(6, 5, i + 1)
        plt.scatter(v, f, c=hist2d(v, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
        plt.plot(mu, f.max(), 'k+', markersize=15)
        plt.title(f'{k} = {mu:.3g}', fontdict={'size': 9})  # limit to 40 characters
        if i % 5 != 0:  # drop y tick labels except on the leftmost column
            plt.yticks([])
        print(f'{k:>15}: {mu:.3g}')
    f = evolve_csv.with_suffix('.png')  # filename
    plt.savefig(f, dpi=200)
    plt.close()
    print(f'Saved {f}')
496
+
497
+
498
def plot_results(file='path/to/results.csv', dir=''):
    """Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')

    Plots every results*.csv found next to `file` (or in `dir`) into a 2x5 grid saved as results.png.
    """
    save_dir = Path(file).parent if file else Path(dir)
    fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
    ax = ax.ravel()
    files = list(save_dir.glob('results*.csv'))
    assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
    for f in files:
        try:
            data = pd.read_csv(f)
            s = [x.strip() for x in data.columns]
            x = data.values[:, 0]  # first column is the epoch index
            for i, j in enumerate([1, 2, 3, 4, 5, 8, 9, 10, 6, 7]):  # reorder columns for display
                y = data.values[:, j].astype('float')
                # y[y == 0] = np.nan  # don't show zero values
                ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)
                ax[i].set_title(s[j], fontsize=12)
                # if j in [8, 9, 10]:  # share train and val loss y axes
                #     ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
        except Exception as e:  # best-effort: warn and skip unreadable/malformed CSVs
            LOGGER.info(f'Warning: Plotting error for {f}: {e}')
    ax[1].legend()
    fig.savefig(save_dir / 'results.png', dpi=200)
    plt.close()
522
+
523
+
524
def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
    """Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()

    Reads frames*.txt files in save_dir and plots each logged metric over time into
    idetection_profile.png. `start`/`stop` bound the row range; `labels` overrides legend names.
    """
    ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
    s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
    files = list(Path(save_dir).glob('frames*.txt'))
    for fi, f in enumerate(files):
        try:
            results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
            n = results.shape[1]  # number of rows
            x = np.arange(start, min(stop, n) if stop else n)
            results = results[:, x]
            t = (results[0] - results[0].min())  # set t0=0s
            results[0] = x  # replace timestamps with frame indices
            for i, a in enumerate(ax):
                if i < len(results):
                    label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
                    a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
                    a.set_title(s[i])
                    a.set_xlabel('time (s)')
                    # if fi == len(files) - 1:
                    #     a.set_ylim(bottom=0)
                    for side in ['top', 'right']:
                        a.spines[side].set_visible(False)
                else:
                    a.remove()  # drop unused subplots
        except Exception as e:  # best-effort: warn and continue with remaining files
            print(f'Warning: Plotting error for {f}; {e}')
    ax[1].legend()
    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
553
+
554
+
555
def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
    """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.

    Args:
        xyxy: box in pixel xyxy coordinates (list or tensor).
        im: HWC image array (BGR channel order, as from cv2 — TODO confirm with callers).
        file: destination path; auto-incremented and saved with a .jpg suffix.
        gain: multiplier applied to the box width/height.
        pad: pixels added to each side after gain.
        square: expand the box to a square before cropping.
        BGR: return the crop in BGR order instead of flipping to RGB.
        save: write the crop to disk.
    """
    xyxy = torch.tensor(xyxy).view(-1, 4)
    b = xyxy2xywh(xyxy)  # boxes
    if square:
        b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
    b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
    xyxy = xywh2xyxy(b).long()
    clip_boxes(xyxy, im.shape)
    crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
    if save:
        file.parent.mkdir(parents=True, exist_ok=True)  # make directory
        f = str(increment_path(file).with_suffix('.jpg'))
        # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
        Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
    return crop
torch_utils.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import platform
4
+ import subprocess
5
+ import time
6
+ import warnings
7
+ from contextlib import contextmanager
8
+ from copy import deepcopy
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+
17
+ from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
18
+ from utils.lion import Lion
19
+
20
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
21
+ RANK = int(os.getenv('RANK', -1))
22
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
23
+
24
+ try:
25
+ import thop # for FLOPs computation
26
+ except ImportError:
27
+ thop = None
28
+
29
+ # Suppress PyTorch warnings
30
+ warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
31
+ warnings.filterwarnings('ignore', category=UserWarning)
32
+
33
+
34
def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
    """Decorator factory: applies torch.inference_mode() on torch>=1.9.0, else torch.no_grad().

    The version check is evaluated once at import time and cached in the default argument.
    """
    def decorate(fn):
        # instantiate the chosen context-manager class and use it as a decorator
        return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)

    return decorate
40
+
41
+
42
def smartCrossEntropyLoss(label_smoothing=0.0):
    """Return nn.CrossEntropyLoss with label smoothing enabled when torch>=1.10.0 supports it.

    On older torch, smoothing is silently dropped (with a warning if it was requested).
    """
    if check_version(torch.__version__, '1.10.0'):
        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    if label_smoothing > 0:
        LOGGER.warning(f'WARNING ⚠️ label smoothing {label_smoothing} requires torch>=1.10.0')
    return nn.CrossEntropyLoss()
49
+
50
+
51
def smart_DDP(model):
    """Wrap `model` in DistributedDataParallel with version checks.

    Refuses torch==1.12.0 (known DDP bug) and enables static_graph on torch>=1.11.
    Uses the module-level LOCAL_RANK for device placement.
    """
    assert not check_version(torch.__version__, '1.12.0', pinned=True), \
        'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
        'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
    if check_version(torch.__version__, '1.11.0'):
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
    else:
        return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
60
+
61
+
62
def reshape_classifier_output(model, n=1000):
    """Update a TorchVision classification model's head to output `n` classes if required.

    Handles YOLOv5 Classify() heads, plain nn.Linear heads (ResNet/EfficientNet) and
    nn.Sequential heads containing a Linear or Conv2d layer. No-op when already `n`.
    """
    from models.common import Classify
    name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1]  # last module
    if isinstance(m, Classify):  # YOLOv5 Classify() head
        if m.linear.out_features != n:
            m.linear = nn.Linear(m.linear.in_features, n)
    elif isinstance(m, nn.Linear):  # ResNet, EfficientNet
        if m.out_features != n:
            setattr(model, name, nn.Linear(m.in_features, n))
    elif isinstance(m, nn.Sequential):
        types = [type(x) for x in m]
        if nn.Linear in types:
            i = types.index(nn.Linear)  # nn.Linear index
            if m[i].out_features != n:
                m[i] = nn.Linear(m[i].in_features, n)
        elif nn.Conv2d in types:
            i = types.index(nn.Conv2d)  # nn.Conv2d index
            if m[i].out_channels != n:
                m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
82
+
83
+
84
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Make all non-master DDP processes wait while the local master (rank 0/-1) runs the with-body first."""
    is_master = local_rank in (-1, 0)
    if not is_master:
        dist.barrier(device_ids=[local_rank])  # workers block here until the master reaches its barrier
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])  # master releases the waiting workers once its work is done
92
+
93
+
94
def device_count():
    """Count NVIDIA GPUs by parsing `nvidia-smi -L`; returns 0 on any failure.

    Safe alternative to torch.cuda.device_count(); Linux and Windows only.
    """
    system = platform.system()
    assert system in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
    if system == 'Linux':
        cmd = 'nvidia-smi -L | wc -l'
    else:
        cmd = 'nvidia-smi -L | find /c /v ""'  # Windows
    try:
        out = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode()
        return int(out.split()[-1])
    except Exception:
        return 0  # nvidia-smi missing or failed -> assume no GPUs
102
+
103
+
104
def select_device(device='', batch_size=0, newline=True):
    """Resolve a device request string to a torch.device and log the selection.

    Args:
        device: None/'' (auto), 'cpu', 'mps', '0', '0,1,2,3' or 'cuda:0'.
        batch_size: when >0 with multiple GPUs, must be divisible by the GPU count.
        newline: keep the trailing newline on the log message.
    """
    s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
    device = str(device).strip().lower().replace('cuda:', '').replace('none', '')  # to string, 'cuda:0' to '0'
    cpu = device == 'cpu'
    mps = device == 'mps'  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable - must be before assert is_available()
        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(',') if device else '0'  # range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size > 0:  # check batch_size is divisible by device_count
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
        space = ' ' * (len(s) + 1)  # align continuation lines under the banner
        for i, d in enumerate(devices):
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MB
        arg = 'cuda:0'
    elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available():  # prefer MPS if available
        s += 'MPS\n'
        arg = 'mps'
    else:  # revert to CPU
        s += 'CPU\n'
        arg = 'cpu'

    if not newline:
        s = s.rstrip()
    LOGGER.info(s)
    return torch.device(arg)
138
+
139
+
140
def time_sync():
    """Return time.time() after synchronizing CUDA, so pending GPU work is included in timings."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for all queued CUDA kernels to finish
    return time.time()
145
+
146
+
147
def profile(input, ops, n=10, device=None):
    """ YOLOv5 speed/memory/FLOPs profiler
    Usage:
        input = torch.randn(16, 3, 640, 640)
        m1 = lambda x: x * torch.sigmoid(x)
        m2 = nn.SiLU()
        profile(input, [m1, m2], n=100)  # profile over 100 iterations

    Returns a list with one entry per (input, op) pair:
    [params, GFLOPs, GPU mem (GB), forward ms, backward ms, in shape, out shape], or None on failure.
    """
    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
          f"{'input':>24s}{'output':>24s}")

    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True  # so the backward pass can be timed
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, 'to') else m  # device
            m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                flops = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2  # GFLOPs
            except Exception:
                flops = 0

            try:
                for _ in range(n):
                    t[0] = time_sync()
                    y = m(x)
                    t[1] = time_sync()
                    try:
                        _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                        t[2] = time_sync()
                    except Exception:  # no backward method
                        # print(e)  # for debug
                        t[2] = float('nan')
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:  # record the failure but keep profiling remaining ops
                print(e)
                results.append(None)
            torch.cuda.empty_cache()
    return results
196
+
197
+
198
def is_parallel(model):
    """Return True when `model` is a DataParallel or DistributedDataParallel wrapper (exact type match)."""
    parallel_types = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return type(model) in parallel_types
201
+
202
+
203
def de_parallel(model):
    """Unwrap a DP/DDP-wrapped model to its underlying module; pass plain models through unchanged."""
    wrapped = type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if wrapped else model
206
+
207
+
208
def initialize_weights(model):
    """Apply YOLOv5 default init: tuned BatchNorm eps/momentum and in-place activations.

    Conv2d weights are deliberately left at PyTorch defaults (explicit kaiming init is disabled).
    """
    for module in model.modules():
        kind = type(module)  # exact type match, subclasses intentionally excluded
        if kind is nn.Conv2d:
            pass  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif kind is nn.BatchNorm2d:
            module.eps = 1e-3
            module.momentum = 0.03
        elif kind in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU):
            module.inplace = True  # save memory by computing activations in place
218
+
219
+
220
def find_modules(model, mclass=nn.Conv2d):
    """Return indices of layers in model.module_list that are instances of `mclass`."""
    matches = []
    for index, layer in enumerate(model.module_list):
        if isinstance(layer, mclass):
            matches.append(index)
    return matches
223
+
224
+
225
def sparsity(model):
    """Return the global sparsity of `model`: fraction of all parameter entries that are exactly zero."""
    total, zeros = 0, 0
    for p in model.parameters():
        total += p.numel()
        zeros += (p == 0).sum()  # tensor accumulator, so the result is a tensor
    return zeros / total
232
+
233
+
234
def prune(model, amount=0.3):
    """L1-unstructured prune every Conv2d weight to `amount` sparsity, then make the pruning permanent."""
    import torch.nn.utils.prune as prune  # local import; intentionally shadows this function's name
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # prune
            prune.remove(m, 'weight')  # make permanent (folds the mask into the weight tensor)
    LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
242
+
243
+
244
def fuse_conv_and_bn(conv, bn):
    """Fold a BatchNorm2d into the preceding Conv2d and return the single fused Conv2d.

    See https://tehnokv.com/posts/fusing-batchnorm-and-conv/. The fused layer is
    created with requires_grad=False (inference use).
    """
    fused = nn.Conv2d(conv.in_channels,
                      conv.out_channels,
                      kernel_size=conv.kernel_size,
                      stride=conv.stride,
                      padding=conv.padding,
                      dilation=conv.dilation,
                      groups=conv.groups,
                      bias=True).requires_grad_(False).to(conv.weight.device)

    # Fused weight: W' = diag(gamma / sqrt(var + eps)) @ W
    bn_scale = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    conv_w = conv.weight.clone().view(conv.out_channels, -1)
    fused.weight.copy_(torch.mm(bn_scale, conv_w).view(fused.weight.shape))

    # Fused bias: b' = scale @ b + (beta - gamma * mean / sqrt(var + eps))
    conv_b = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    bn_shift = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fused.bias.copy_(torch.mm(bn_scale, conv_b.reshape(-1, 1)).reshape(-1) + bn_shift)

    return fused
266
+
267
+
268
def model_info(model, verbose=False, imgsz=640):
    """Log a model summary: layer count, parameters, gradients and (if thop is available) GFLOPs.

    Args:
        model: the model to summarize.
        verbose: also print a per-parameter table (name, shape, mean, std).
        imgsz: int or [h, w] used to scale the stride-level FLOPs estimate to full image size.
    """
    n_p = sum(x.numel() for x in model.parameters())  # number parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # number gradients
    if verbose:
        print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:  # FLOPs: profile at stride-sized input, then scale up to imgsz
        p = next(model.parameters())
        stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32  # max stride
        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
        flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2  # stride GFLOPs
        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
        fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs'  # 640x640 GFLOPs
    except Exception:  # thop missing or model not profilable -> omit FLOPs from the summary
        fs = ''

    name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
    LOGGER.info(f"{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
291
+
292
+
293
def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
    """Scale a BCHW image batch by `ratio`, then pad: to the original shape when same_shape,
    otherwise to the nearest gs-multiple of the scaled size. Returns the input unchanged for ratio 1.0."""
    if ratio == 1.0:
        return img
    h, w = img.shape[2:]
    new_size = (int(h * ratio), int(w * ratio))
    img = F.interpolate(img, size=new_size, mode='bilinear', align_corners=False)  # resize
    if not same_shape:  # shrink the pad target to a gs-multiple instead of the full original size
        h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
    return F.pad(img, [0, w - new_size[1], 0, h - new_size[0]], value=0.447)  # value = imagenet mean
303
+
304
+
305
def copy_attr(a, b, include=(), exclude=()):
    """Copy b's instance attributes onto a.

    Underscore-prefixed keys and `exclude` entries are always skipped; a non-empty
    `include` restricts copying to just those names.
    """
    for key, value in b.__dict__.items():
        wanted = not include or key in include
        if wanted and not key.startswith('_') and key not in exclude:
            setattr(a, key, value)
312
+
313
+
314
def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    """Build a YOLOv5 3-param-group optimizer: 0) weights with decay, 1) weights without decay, 2) biases without decay.

    Args:
        model: model whose parameters are grouped.
        name: optimizer type: 'Adam', 'AdamW', 'RMSProp', 'SGD' or 'LION'.
        lr: initial learning rate.
        momentum: SGD momentum / Adam beta1.
        decay: weight decay applied to group 0 only.
    Returns:
        A configured torch.optim.Optimizer (or Lion) with the three groups registered.
    Raises:
        NotImplementedError: for an unknown optimizer name.
    """
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()

    def _add_implicit(attr):
        # ImplicitA/ImplicitM-style modules: either a single module with .implicit,
        # or an iterable of such modules. Their params train without weight decay -> g[1].
        if hasattr(attr, 'implicit'):
            g[1].append(attr.implicit)
        else:
            for iv in attr:
                g[1].append(iv.implicit)

    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
            g[2].append(v.bias)
        if isinstance(v, bn):  # norm-layer weight (no decay)
            g[1].append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
            g[0].append(v.weight)

        # implicit-knowledge attributes im/ia and their numbered variants im2..im7 / ia2..ia7
        # (previously 14 copy-pasted blocks; collapsed into one loop, same append order)
        for suffix in ('', '2', '3', '4', '5', '6', '7'):
            for prefix in ('im', 'ia'):
                if hasattr(v, prefix + suffix):
                    _add_implicit(getattr(v, prefix + suffix))

    # Biases (g[2]) seed the optimizer; the other groups are added below with their own decay.
    if name == 'Adam':
        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
    elif name == 'AdamW':
        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0, amsgrad=True)
    elif name == 'RMSProp':
        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == 'SGD':
        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    elif name == 'LION':
        optimizer = Lion(g[2], lr=lr, betas=(momentum, 0.99), weight_decay=0.0)
    else:
        raise NotImplementedError(f'Optimizer {name} not implemented.')

    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
                f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
    return optimizer
444
+
445
+
446
def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
    """torch.hub.load() wrapper with version-aware kwargs and a force_reload retry on failure."""
    if check_version(torch.__version__, '1.9.1'):
        kwargs['skip_validation'] = True  # validation causes GitHub API rate limit errors
    if check_version(torch.__version__, '1.12.0'):
        kwargs['trust_repo'] = True  # argument required starting in torch 0.12
    try:
        return torch.hub.load(repo, model, **kwargs)
    except Exception:  # e.g. stale cache -> retry with a fresh download
        return torch.hub.load(repo, model, force_reload=True, **kwargs)
456
+
457
+
458
def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
    """Resume training state from a partially trained checkpoint dict.

    Restores optimizer state, best fitness and (optionally) EMA weights/updates.

    Returns:
        (best_fitness, start_epoch, epochs) — `epochs` is extended when fine-tuning a finished run.
    Raises:
        AssertionError: when resume=True but the checkpoint's training already finished.
    """
    best_fitness = 0.0
    start_epoch = ckpt['epoch'] + 1
    if ckpt['optimizer'] is not None:
        optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
        best_fitness = ckpt['best_fitness']
    if ema and ckpt.get('ema'):
        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
        ema.updates = ckpt['updates']
    if resume:
        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
                                f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
        LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
    if epochs < start_epoch:  # requested total already reached -> treat `epochs` as additional fine-tune epochs
        LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
        epochs += ckpt['epoch']  # finetune additional epochs
    return best_fitness, start_epoch, epochs
476
+
477
+
478
class EarlyStopping:
    """Simple best-fitness early stopper for YOLOv5 training loops."""

    def __init__(self, patience=30):
        # patience=0 (falsy) disables early stopping entirely
        self.best_fitness = 0.0  # best metric seen so far, i.e. mAP
        self.best_epoch = 0
        self.patience = patience or float('inf')  # epochs to wait after fitness stops improving
        self.possible_stop = False  # True when the next epoch could trigger a stop

    def __call__(self, epoch, fitness):
        """Record this epoch's fitness; return True when training should stop."""
        if fitness >= self.best_fitness:  # >= tolerates the early zero-fitness warm-up stage
            self.best_fitness = fitness
            self.best_epoch = epoch
        stale = epoch - self.best_epoch  # epochs elapsed without improvement
        self.possible_stop = stale >= (self.patience - 1)
        should_stop = stale >= self.patience
        if should_stop:
            LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
                        f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
                        f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
                        f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
        return should_stop
499
+
500
+
501
class ModelEMA:
    """Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers).
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Shadow copy of the (de-parallelized) model, kept in eval mode
        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates applied so far
        # decay ramps from ~0 towards `decay`, so early updates track the live model closely
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for param in self.ema.parameters():
            param.requires_grad_(False)  # EMA weights are never trained directly

    def update(self, model):
        """Blend the live model's weights into the EMA state in place."""
        self.updates += 1
        d = self.decay(self.updates)
        model_sd = de_parallel(model).state_dict()
        for name, ema_v in self.ema.state_dict().items():
            if not ema_v.dtype.is_floating_point:  # skip non-float entries (integer buffers)
                continue
            # ema_v = d * ema_v + (1 - d) * model_v, applied in place on the state tensor
            ema_v *= d
            ema_v += (1 - d) * model_sd[name].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy plain (non-weight) attributes from `model` onto the EMA model."""
        copy_attr(self.ema, model, include, exclude)