Van committed on

Commit f5f9ab3 · 1 Parent(s): 392c45b

with repo

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set.

Files changed (50)
  1. app.py +2 -1
  2. best.pt +0 -3
  3. datasets.py +0 -38
  4. yolov9/README.md +153 -0
  5. yolov9/benchmarks.py +142 -0
  6. yolov9/classify/predict.py +224 -0
  7. yolov9/classify/train.py +333 -0
  8. yolov9/classify/val.py +170 -0
  9. yolov9/data/coco.yaml +106 -0
  10. yolov9/data/hyps/hyp.scratch-high.yaml +30 -0
  11. yolov9/detect.py +232 -0
  12. yolov9/export.py +606 -0
  13. yolov9/figure/performance.png +0 -0
  14. yolov9/hubconf.py +107 -0
  15. yolov9/models/__init__.py +1 -0
  16. common.py → yolov9/models/common.py +0 -0
  17. yolov9/models/detect/gelan-c.yaml +80 -0
  18. yolov9/models/detect/gelan-e.yaml +121 -0
  19. yolov9/models/detect/gelan.yaml +80 -0
  20. yolov9/models/detect/yolov7-af.yaml +137 -0
  21. yolov9/models/detect/yolov9-c.yaml +124 -0
  22. yolov9/models/detect/yolov9-e.yaml +144 -0
  23. yolov9/models/detect/yolov9.yaml +117 -0
  24. experimental.py → yolov9/models/experimental.py +0 -0
  25. yolov9/models/hub/anchors.yaml +59 -0
  26. yolov9/models/hub/yolov3-spp.yaml +51 -0
  27. yolov9/models/hub/yolov3-tiny.yaml +41 -0
  28. yolov9/models/hub/yolov3.yaml +51 -0
  29. yolov9/models/panoptic/yolov7-af-pan.yaml +137 -0
  30. yolov9/models/segment/yolov7-af-seg.yaml +136 -0
  31. tf.py → yolov9/models/tf.py +0 -0
  32. yolo.py → yolov9/models/yolo.py +6 -6
  33. yolov9/panoptic/predict.py +246 -0
  34. yolov9/panoptic/train.py +662 -0
  35. yolov9/panoptic/val.py +597 -0
  36. yolov9/requirements.txt +47 -0
  37. yolov9/scripts/get_coco.sh +22 -0
  38. yolov9/segment/predict.py +246 -0
  39. yolov9/segment/train.py +646 -0
  40. yolov9/segment/val.py +457 -0
  41. yolov9/train.py +634 -0
  42. yolov9/train_dual.py +644 -0
  43. yolov9/train_triple.py +636 -0
  44. yolov9/utils/__init__.py +75 -0
  45. yolov9/utils/activations.py +98 -0
  46. yolov9/utils/augmentations.py +395 -0
  47. yolov9/utils/autoanchor.py +164 -0
  48. yolov9/utils/autobatch.py +67 -0
  49. yolov9/utils/callbacks.py +71 -0
  50. yolov9/utils/dataloaders.py +1217 -0
app.py CHANGED
@@ -8,7 +8,8 @@ from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.image import show_cam_on_image
 import gradio as gr
 from huggingface_hub import hf_hub_download
-from yolo import Model
+import yolov9
+from yolov9.Yolo import Model
 
 model = Model('best.pt')
 
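The change above swaps the flat `yolo.py` import for the vendored `yolov9` package, and the commit also drops `best.pt` from the repo (see the next hunk). A minimal sketch of how the app can then resolve its weights at startup; the `repo_id` below is a placeholder for illustration, not something shown in this diff:

```python
from huggingface_hub import hf_hub_download

from yolov9.Yolo import Model  # import path introduced by this commit

# Placeholder repo id; the Space's actual source for best.pt is not shown here.
weights_path = hf_hub_download(repo_id="user/space-name", filename="best.pt")
model = Model(weights_path)
```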
best.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:bbb8ecf41921ab2db120751f365c0de8cba9c92ab0e612c347c27dfe758d2c30
-size 204626214
datasets.py DELETED
@@ -1,38 +0,0 @@
-#!/usr/bin/env python3
-"""
-Module containing wrapper classes for PyTorch Datasets
-Author: Shilpaj Bhalerao
-Date: Jun 25, 2023
-"""
-# Standard Library Imports
-from typing import Tuple
-
-# Third-Party Imports
-from torchvision import datasets, transforms
-
-
-class AlbumDataset(datasets.CIFAR10):
-    """
-    Wrapper class to use albumentations library with PyTorch Dataset
-    """
-    def __init__(self, root: str = "./data", train: bool = True, download: bool = True, transform: list = None):
-        """
-        Constructor
-        :param root: Directory at which data is stored
-        :param train: Param to distinguish if data is training or test
-        :param download: Param to download the dataset from source
-        :param transform: List of transformations to be performed on the dataset
-        """
-        super().__init__(root=root, train=train, download=download, transform=transform)
-
-    def __getitem__(self, index: int) -> Tuple:
-        """
-        Method to return an image and its label
-        :param index: Index of the image and label in the dataset
-        """
-        image, label = self.data[index], self.targets[index]
-
-        if self.transform:
-            transformed = self.transform(image=image)
-            image = transformed["image"]
-        return image, label
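For the record, the removed `AlbumDataset` existed because albumentations pipelines are called with keyword arguments and return a dict, unlike torchvision transforms. A minimal usage sketch of the deleted class, assuming albumentations is installed (the normalization statistics are the usual CIFAR-10 values, shown only for illustration):

```python
import albumentations as A
from albumentations.pytorch import ToTensorV2

# albumentations expects transform(image=...) and returns a dict,
# which is exactly what AlbumDataset.__getitem__ handled.
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)),
    ToTensorV2(),
])

dataset = AlbumDataset(root="./data", train=True, download=True, transform=train_transform)
image, label = dataset[0]  # image is a CHW tensor, label an int
```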
yolov9/README.md ADDED
@@ -0,0 +1,153 @@
+# YOLOv9
+
+Implementation of paper - [YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information](https://arxiv.org/abs/2402.13616)
+
+<div align="center">
+    <a href="./">
+        <img src="./figure/performance.png" width="79%"/>
+    </a>
+</div>
+
+
+## Performance
+
+MS COCO
+
+| Model | Test Size | AP<sup>val</sup> | AP<sub>50</sub><sup>val</sup> | AP<sub>75</sub><sup>val</sup> | Param. | FLOPs |
+| :-- | :-: | :-: | :-: | :-: | :-: | :-: |
+| [**YOLOv9-S**]() | 640 | **46.8%** | **63.4%** | **50.7%** | **7.2M** | **26.7G** |
+| [**YOLOv9-M**]() | 640 | **51.4%** | **68.1%** | **56.1%** | **20.1M** | **76.8G** |
+| [**YOLOv9-C**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c.pt) | 640 | **53.0%** | **70.2%** | **57.8%** | **25.5M** | **102.8G** |
+| [**YOLOv9-E**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-e.pt) | 640 | **55.6%** | **72.8%** | **60.6%** | **58.1M** | **192.5G** |
+
+<!-- small and medium models will be released after the paper is accepted and published. -->
+
+
+## Installation
+
+Docker environment (recommended)
+<details><summary> <b>Expand</b> </summary>
+
+``` shell
+# create the docker container; you can change the shared memory size if you have more
+nvidia-docker run --name yolov9 -it -v your_coco_path/:/coco/ -v your_code_path/:/yolov9 --shm-size=64g nvcr.io/nvidia/pytorch:21.11-py3
+
+# apt install required packages
+apt update
+apt install -y zip htop screen libgl1-mesa-glx
+
+# pip install required packages
+pip install seaborn thop
+
+# go to code folder
+cd /yolov9
+```
+
+</details>
+
+
+## Evaluation
+
+[`yolov9-c.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c.pt) [`yolov9-e.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-e.pt) [`gelan-c.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c.pt) [`gelan-e.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-e.pt)
+
+``` shell
+# evaluate yolov9 models
+python val_dual.py --data data/coco.yaml --img 640 --batch 32 --conf 0.001 --iou 0.7 --device 0 --weights './yolov9-c.pt' --save-json --name yolov9_c_640_val
+
+# evaluate gelan models
+# python val.py --data data/coco.yaml --img 640 --batch 32 --conf 0.001 --iou 0.7 --device 0 --weights './gelan-c.pt' --save-json --name gelan_c_640_val
+```
+
+You will get the results:
+
+```
+ Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.530
+ Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.702
+ Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.578
+ Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.362
+ Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.585
+ Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.693
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.392
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.652
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.702
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.541
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.760
+ Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.844
+```
+
+
+## Training
+
+Data preparation
+
+``` shell
+bash scripts/get_coco.sh
+```
+
+* Download MS COCO dataset images ([train](http://images.cocodataset.org/zips/train2017.zip), [val](http://images.cocodataset.org/zips/val2017.zip), [test](http://images.cocodataset.org/zips/test2017.zip)) and [labels](https://github.com/WongKinYiu/yolov7/releases/download/v0.1/coco2017labels-segments.zip). If you have previously used a different version of YOLO, we strongly recommend that you delete the `train2017.cache` and `val2017.cache` files and redownload the [labels](https://github.com/WongKinYiu/yolov7/releases/download/v0.1/coco2017labels-segments.zip).
+
+Single GPU training
+
+``` shell
+# train yolov9 models
+python train_dual.py --workers 8 --device 0 --batch 16 --data data/coco.yaml --img 640 --cfg models/detect/yolov9-c.yaml --weights '' --name yolov9-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15
+
+# train gelan models
+# python train.py --workers 8 --device 0 --batch 32 --data data/coco.yaml --img 640 --cfg models/detect/gelan-c.yaml --weights '' --name gelan-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15
+```
+
+Multiple GPU training
+
+``` shell
+# train yolov9 models
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 9527 train_dual.py --workers 8 --device 0,1,2,3,4,5,6,7 --sync-bn --batch 128 --data data/coco.yaml --img 640 --cfg models/detect/yolov9-c.yaml --weights '' --name yolov9-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15
+
+# train gelan models
+# python -m torch.distributed.launch --nproc_per_node 4 --master_port 9527 train.py --workers 8 --device 0,1,2,3 --sync-bn --batch 128 --data data/coco.yaml --img 640 --cfg models/detect/gelan-c.yaml --weights '' --name gelan-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15
+```
+
+
+## Re-parameterization
+
+Under construction.
+
+
+## Citation
+
+```
+@article{wang2024yolov9,
+  title={{YOLOv9}: Learning What You Want to Learn Using Programmable Gradient Information},
+  author={Wang, Chien-Yao and Liao, Hong-Yuan Mark},
+  journal={arXiv preprint arXiv:2402.13616},
+  year={2024}
+}
+```
+
+```
+@article{chang2023yolor,
+  title={{YOLOR}-Based Multi-Task Learning},
+  author={Chang, Hung-Shuo and Wang, Chien-Yao and Wang, Richard Robert and Chou, Gene and Liao, Hong-Yuan Mark},
+  journal={arXiv preprint arXiv:2309.16921},
+  year={2023}
+}
+```
+
+
+## Teaser
+
+Parts of the code of [YOLOR-Based Multi-Task Learning](https://arxiv.org/abs/2309.16921) are released in this repository.
+
+
+## Acknowledgements
+
+<details><summary> <b>Expand</b> </summary>
+
+* [https://github.com/AlexeyAB/darknet](https://github.com/AlexeyAB/darknet)
+* [https://github.com/WongKinYiu/yolor](https://github.com/WongKinYiu/yolor)
+* [https://github.com/WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7)
+* [https://github.com/VDIGPKU/DynamicDet](https://github.com/VDIGPKU/DynamicDet)
+* [https://github.com/DingXiaoH/RepVGG](https://github.com/DingXiaoH/RepVGG)
+* [https://github.com/ultralytics/yolov5](https://github.com/ultralytics/yolov5)
+* [https://github.com/meituan/YOLOv6](https://github.com/meituan/YOLOv6)
+
+</details>
yolov9/benchmarks.py ADDED
@@ -0,0 +1,142 @@
+import argparse
+import platform
+import sys
+import time
+from pathlib import Path
+
+import pandas as pd
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[0]  # YOLO root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+# ROOT = ROOT.relative_to(Path.cwd())  # relative
+
+import export
+from models.experimental import attempt_load
+from models.yolo import SegmentationModel
+from segment.val import run as val_seg
+from utils import notebook_init
+from utils.general import LOGGER, check_yaml, file_size, print_args
+from utils.torch_utils import select_device
+from val import run as val_det
+
+
+def run(
+        weights=ROOT / 'yolo.pt',  # weights path
+        imgsz=640,  # inference size (pixels)
+        batch_size=1,  # batch size
+        data=ROOT / 'data/coco.yaml',  # dataset.yaml path
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        half=False,  # use FP16 half-precision inference
+        test=False,  # test exports only
+        pt_only=False,  # test PyTorch only
+        hard_fail=False,  # throw error on benchmark failure
+):
+    y, t = [], time.time()
+    device = select_device(device)
+    model_type = type(attempt_load(weights, fuse=False))  # DetectionModel, SegmentationModel, etc.
+    for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows():  # index, (name, file, suffix, CPU, GPU)
+        try:
+            assert i not in (9, 10), 'inference not supported'  # Edge TPU and TF.js are unsupported
+            assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13'  # CoreML
+            if 'cpu' in device.type:
+                assert cpu, 'inference not supported on CPU'
+            if 'cuda' in device.type:
+                assert gpu, 'inference not supported on GPU'
+
+            # Export
+            if f == '-':
+                w = weights  # PyTorch format
+            else:
+                w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1]  # all others
+            assert suffix in str(w), 'export failed'
+
+            # Validate
+            if model_type == SegmentationModel:
+                result = val_seg(data, w, batch_size, imgsz, plots=False, device=device, task='speed', half=half)
+                metric = result[0][7]  # (box(p, r, map50, map), mask(p, r, map50, map), *loss(box, obj, cls))
+            else:  # DetectionModel:
+                result = val_det(data, w, batch_size, imgsz, plots=False, device=device, task='speed', half=half)
+                metric = result[0][3]  # (p, r, map50, map, *loss(box, obj, cls))
+            speed = result[2][1]  # times (preprocess, inference, postprocess)
+            y.append([name, round(file_size(w), 1), round(metric, 4), round(speed, 2)])  # MB, mAP, t_inference
+        except Exception as e:
+            if hard_fail:
+                assert type(e) is AssertionError, f'Benchmark --hard-fail for {name}: {e}'
+            LOGGER.warning(f'WARNING ⚠️ Benchmark failure for {name}: {e}')
+            y.append([name, None, None, None])  # mAP, t_inference
+        if pt_only and i == 0:
+            break  # break after PyTorch
+
+    # Print results
+    LOGGER.info('\n')
+    parse_opt()
+    notebook_init()  # print system info
+    c = ['Format', 'Size (MB)', 'mAP50-95', 'Inference time (ms)'] if map else ['Format', 'Export', '', '']
+    py = pd.DataFrame(y, columns=c)
+    LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
+    LOGGER.info(str(py if map else py.iloc[:, :2]))
+    if hard_fail and isinstance(hard_fail, str):
+        metrics = py['mAP50-95'].array  # values to compare to floor
+        floor = eval(hard_fail)  # minimum metric floor to pass
+        assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: mAP50-95 < floor {floor}'
+    return py
+
+
+def test(
+        weights=ROOT / 'yolo.pt',  # weights path
+        imgsz=640,  # inference size (pixels)
+        batch_size=1,  # batch size
+        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        half=False,  # use FP16 half-precision inference
+        test=False,  # test exports only
+        pt_only=False,  # test PyTorch only
+        hard_fail=False,  # throw error on benchmark failure
+):
+    y, t = [], time.time()
+    device = select_device(device)
+    for i, (name, f, suffix, gpu) in export.export_formats().iterrows():  # index, (name, file, suffix, gpu-capable)
+        try:
+            w = weights if f == '-' else \
+                export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1]  # weights
+            assert suffix in str(w), 'export failed'
+            y.append([name, True])
+        except Exception:
+            y.append([name, False])  # mAP, t_inference
+
+    # Print results
+    LOGGER.info('\n')
+    parse_opt()
+    notebook_init()  # print system info
+    py = pd.DataFrame(y, columns=['Format', 'Export'])
+    LOGGER.info(f'\nExports complete ({time.time() - t:.2f}s)')
+    LOGGER.info(str(py))
+    return py
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='weights path')
+    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
+    parser.add_argument('--batch-size', type=int, default=1, help='batch size')
+    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--test', action='store_true', help='test exports only')
+    parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
+    parser.add_argument('--hard-fail', nargs='?', const=True, default=False, help='Exception on error or < min metric')
+    opt = parser.parse_args()
+    opt.data = check_yaml(opt.data)  # check YAML
+    print_args(vars(opt))
+    return opt
+
+
+def main(opt):
+    test(**vars(opt)) if opt.test else run(**vars(opt))
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)
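`benchmarks.py` can also be driven from Python rather than the CLI. A minimal sketch, assuming it is invoked from the `yolov9/` directory so its repo-relative imports resolve; `pt_only=True` stops after the native PyTorch format instead of exporting every backend:

```python
import benchmarks  # yolov9/benchmarks.py

# Benchmark only the PyTorch checkpoint on CPU; returns a pandas DataFrame
# with Format, Size (MB), mAP50-95 and inference-time columns.
results = benchmarks.run(weights='yolov9-c.pt', imgsz=640, device='cpu', pt_only=True)
print(results)
```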
yolov9/classify/predict.py ADDED
@@ -0,0 +1,224 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Run YOLOv5 classification inference on images, videos, directories, globs, YouTube, webcam, streams, etc.
+
+Usage - sources:
+    $ python classify/predict.py --weights yolov5s-cls.pt --source 0                               # webcam
+                                                                   img.jpg                         # image
+                                                                   vid.mp4                         # video
+                                                                   screen                          # screenshot
+                                                                   path/                           # directory
+                                                                   'path/*.jpg'                    # glob
+                                                                   'https://youtu.be/Zgi9g1ksQHc'  # YouTube
+                                                                   'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP stream
+
+Usage - formats:
+    $ python classify/predict.py --weights yolov5s-cls.pt             # PyTorch
+                                           yolov5s-cls.torchscript    # TorchScript
+                                           yolov5s-cls.onnx           # ONNX Runtime or OpenCV DNN with --dnn
+                                           yolov5s-cls_openvino_model # OpenVINO
+                                           yolov5s-cls.engine         # TensorRT
+                                           yolov5s-cls.mlmodel        # CoreML (macOS-only)
+                                           yolov5s-cls_saved_model    # TensorFlow SavedModel
+                                           yolov5s-cls.pb             # TensorFlow GraphDef
+                                           yolov5s-cls.tflite         # TensorFlow Lite
+                                           yolov5s-cls_edgetpu.tflite # TensorFlow Edge TPU
+                                           yolov5s-cls_paddle_model   # PaddlePaddle
+"""
+
+import argparse
+import os
+import platform
+import sys
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[1]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+from models.common import DetectMultiBackend
+from utils.augmentations import classify_transforms
+from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
+from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
+                           increment_path, print_args, strip_optimizer)
+from utils.plots import Annotator
+from utils.torch_utils import select_device, smart_inference_mode
+
+
+@smart_inference_mode()
+def run(
+        weights=ROOT / 'yolov5s-cls.pt',  # model.pt path(s)
+        source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
+        data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
+        imgsz=(224, 224),  # inference size (height, width)
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        view_img=False,  # show results
+        save_txt=False,  # save results to *.txt
+        nosave=False,  # do not save images/videos
+        augment=False,  # augmented inference
+        visualize=False,  # visualize features
+        update=False,  # update all models
+        project=ROOT / 'runs/predict-cls',  # save results to project/name
+        name='exp',  # save results to project/name
+        exist_ok=False,  # existing project/name ok, do not increment
+        half=False,  # use FP16 half-precision inference
+        dnn=False,  # use OpenCV DNN for ONNX inference
+        vid_stride=1,  # video frame-rate stride
+):
+    source = str(source)
+    save_img = not nosave and not source.endswith('.txt')  # save inference images
+    is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
+    is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
+    webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
+    screenshot = source.lower().startswith('screen')
+    if is_url and is_file:
+        source = check_file(source)  # download
+
+    # Directories
+    save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
+    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+
+    # Load model
+    device = select_device(device)
+    model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
+    stride, names, pt = model.stride, model.names, model.pt
+    imgsz = check_img_size(imgsz, s=stride)  # check image size
+
+    # Dataloader
+    bs = 1  # batch_size
+    if webcam:
+        view_img = check_imshow(warn=True)
+        dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
+        bs = len(dataset)
+    elif screenshot:
+        dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
+    else:
+        dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
+    vid_path, vid_writer = [None] * bs, [None] * bs
+
+    # Run inference
+    model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
+    seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
+    for path, im, im0s, vid_cap, s in dataset:
+        with dt[0]:
+            im = torch.Tensor(im).to(model.device)
+            im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
+            if len(im.shape) == 3:
+                im = im[None]  # expand for batch dim
+
+        # Inference
+        with dt[1]:
+            results = model(im)
+
+        # Post-process
+        with dt[2]:
+            pred = F.softmax(results, dim=1)  # probabilities
+
+        # Process predictions
+        for i, prob in enumerate(pred):  # per image
+            seen += 1
+            if webcam:  # batch_size >= 1
+                p, im0, frame = path[i], im0s[i].copy(), dataset.count
+                s += f'{i}: '
+            else:
+                p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
+
+            p = Path(p)  # to Path
+            save_path = str(save_dir / p.name)  # im.jpg
+            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
+
+            s += '%gx%g ' % im.shape[2:]  # print string
+            annotator = Annotator(im0, example=str(names), pil=True)
+
+            # Print results
+            top5i = prob.argsort(0, descending=True)[:5].tolist()  # top 5 indices
+            s += f"{', '.join(f'{names[j]} {prob[j]:.2f}' for j in top5i)}, "
+
+            # Write results
+            text = '\n'.join(f'{prob[j]:.2f} {names[j]}' for j in top5i)
+            if save_img or view_img:  # Add bbox to image
+                annotator.text((32, 32), text, txt_color=(255, 255, 255))
+            if save_txt:  # Write to file
+                with open(f'{txt_path}.txt', 'a') as f:
+                    f.write(text + '\n')
+
+            # Stream results
+            im0 = annotator.result()
+            if view_img:
+                if platform.system() == 'Linux' and p not in windows:
+                    windows.append(p)
+                    cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
+                    cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
+                cv2.imshow(str(p), im0)
+                cv2.waitKey(1)  # 1 millisecond
+
+            # Save results (image with detections)
+            if save_img:
+                if dataset.mode == 'image':
+                    cv2.imwrite(save_path, im0)
+                else:  # 'video' or 'stream'
+                    if vid_path[i] != save_path:  # new video
+                        vid_path[i] = save_path
+                        if isinstance(vid_writer[i], cv2.VideoWriter):
+                            vid_writer[i].release()  # release previous video writer
+                        if vid_cap:  # video
+                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
+                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        else:  # stream
+                            fps, w, h = 30, im0.shape[1], im0.shape[0]
+                        save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
+                        vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+                    vid_writer[i].write(im0)
+
+        # Print time (inference-only)
+        LOGGER.info(f"{s}{dt[1].dt * 1E3:.1f}ms")
+
+    # Print results
+    t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
+    LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
+    if save_txt or save_img:
+        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
+    if update:
+        strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)')
+    parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
+    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
+    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[224], help='inference size h,w')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--view-img', action='store_true', help='show results')
+    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
+    parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
+    parser.add_argument('--augment', action='store_true', help='augmented inference')
+    parser.add_argument('--visualize', action='store_true', help='visualize features')
+    parser.add_argument('--update', action='store_true', help='update all models')
+    parser.add_argument('--project', default=ROOT / 'runs/predict-cls', help='save results to project/name')
+    parser.add_argument('--name', default='exp', help='save results to project/name')
+    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+    parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
+    opt = parser.parse_args()
+    opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
+    print_args(vars(opt))
+    return opt
+
+
+def main(opt):
+    check_requirements(exclude=('tensorboard', 'thop'))
+    run(**vars(opt))
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)
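Besides the CLI shown in the docstring, `run()` can be called directly from Python. A minimal sketch, assuming a `-cls` checkpoint is available locally and the repo root is importable:

```python
from classify.predict import run

# Classify every image in a folder with a classification checkpoint;
# annotated results land under runs/predict-cls/exp*.
run(weights='yolov5s-cls.pt', source='data/images', imgsz=(224, 224))
```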
yolov9/classify/train.py ADDED
@@ -0,0 +1,333 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Train a YOLOv5 classifier model on a classification dataset
+
+Usage - Single-GPU training:
+    $ python classify/train.py --model yolov5s-cls.pt --data imagenette160 --epochs 5 --img 224
+
+Usage - Multi-GPU DDP training:
+    $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3
+
+Datasets:           --data mnist, fashion-mnist, cifar10, cifar100, imagenette, imagewoof, imagenet, or 'path/to/data'
+YOLOv5-cls models:  --model yolov5n-cls.pt, yolov5s-cls.pt, yolov5m-cls.pt, yolov5l-cls.pt, yolov5x-cls.pt
+Torchvision models: --model resnet50, efficientnet_b0, etc. See https://pytorch.org/vision/stable/models.html
+"""
+
+import argparse
+import os
+import subprocess
+import sys
+import time
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+import torch.hub as hub
+import torch.optim.lr_scheduler as lr_scheduler
+import torchvision
+from torch.cuda import amp
+from tqdm import tqdm
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[1]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+from classify import val as validate
+from models.experimental import attempt_load
+from models.yolo import ClassificationModel, DetectionModel
+from utils.dataloaders import create_classification_dataloader
+from utils.general import (DATASETS_DIR, LOGGER, TQDM_BAR_FORMAT, WorkingDirectory, check_git_info, check_git_status,
+                           check_requirements, colorstr, download, increment_path, init_seeds, print_args, yaml_save)
+from utils.loggers import GenericLogger
+from utils.plots import imshow_cls
+from utils.torch_utils import (ModelEMA, model_info, reshape_classifier_output, select_device, smart_DDP,
+                               smart_optimizer, smartCrossEntropyLoss, torch_distributed_zero_first)
+
+LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+RANK = int(os.getenv('RANK', -1))
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
+GIT_INFO = check_git_info()
+
+
+def train(opt, device):
+    init_seeds(opt.seed + 1 + RANK, deterministic=True)
+    save_dir, data, bs, epochs, nw, imgsz, pretrained = \
+        opt.save_dir, Path(opt.data), opt.batch_size, opt.epochs, min(os.cpu_count() - 1, opt.workers), \
+        opt.imgsz, str(opt.pretrained).lower() == 'true'
+    cuda = device.type != 'cpu'
+
+    # Directories
+    wdir = save_dir / 'weights'
+    wdir.mkdir(parents=True, exist_ok=True)  # make dir
+    last, best = wdir / 'last.pt', wdir / 'best.pt'
+
+    # Save run settings
+    yaml_save(save_dir / 'opt.yaml', vars(opt))
+
+    # Logger
+    logger = GenericLogger(opt=opt, console_logger=LOGGER) if RANK in {-1, 0} else None
+
+    # Download Dataset
+    with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
+        data_dir = data if data.is_dir() else (DATASETS_DIR / data)
+        if not data_dir.is_dir():
+            LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
+            t = time.time()
+            if str(data) == 'imagenet':
+                subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
+            else:
+                url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{data}.zip'
+                download(url, dir=data_dir.parent)
+            s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
+            LOGGER.info(s)
+
+    # Dataloaders
+    nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
+    trainloader = create_classification_dataloader(path=data_dir / 'train',
+                                                   imgsz=imgsz,
+                                                   batch_size=bs // WORLD_SIZE,
+                                                   augment=True,
+                                                   cache=opt.cache,
+                                                   rank=LOCAL_RANK,
+                                                   workers=nw)
+
+    test_dir = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val
+    if RANK in {-1, 0}:
+        testloader = create_classification_dataloader(path=test_dir,
+                                                      imgsz=imgsz,
+                                                      batch_size=bs // WORLD_SIZE * 2,
+                                                      augment=False,
+                                                      cache=opt.cache,
+                                                      rank=-1,
+                                                      workers=nw)
+
+    # Model
+    with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
+        if Path(opt.model).is_file() or opt.model.endswith('.pt'):
+            model = attempt_load(opt.model, device='cpu', fuse=False)
+        elif opt.model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
+            model = torchvision.models.__dict__[opt.model](weights='IMAGENET1K_V1' if pretrained else None)
+        else:
+            m = hub.list('ultralytics/yolov5')  # + hub.list('pytorch/vision')  # models
+            raise ModuleNotFoundError(f'--model {opt.model} not found. Available models are: \n' + '\n'.join(m))
+        if isinstance(model, DetectionModel):
+            LOGGER.warning("WARNING ⚠️ pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'")
+            model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10)  # convert to classification model
+        reshape_classifier_output(model, nc)  # update class count
+    for m in model.modules():
+        if not pretrained and hasattr(m, 'reset_parameters'):
+            m.reset_parameters()
+        if isinstance(m, torch.nn.Dropout) and opt.dropout is not None:
+            m.p = opt.dropout  # set dropout
+    for p in model.parameters():
+        p.requires_grad = True  # for training
+    model = model.to(device)
+
+    # Info
+    if RANK in {-1, 0}:
+        model.names = trainloader.dataset.classes  # attach class names
+        model.transforms = testloader.dataset.torch_transforms  # attach inference transforms
+        model_info(model)
+        if opt.verbose:
+            LOGGER.info(model)
+        images, labels = next(iter(trainloader))
+        file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
+        logger.log_images(file, name='Train Examples')
+        logger.log_graph(model, imgsz)  # log model
+
+    # Optimizer
+    optimizer = smart_optimizer(model, opt.optimizer, opt.lr0, momentum=0.9, decay=opt.decay)
+
+    # Scheduler
+    lrf = 0.01  # final lr (fraction of lr0)
+    # lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf  # cosine
+    lf = lambda x: (1 - x / epochs) * (1 - lrf) + lrf  # linear
+    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+    # scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr0, total_steps=epochs, pct_start=0.1,
+    #                                     final_div_factor=1 / 25 / lrf)
+
+    # EMA
+    ema = ModelEMA(model) if RANK in {-1, 0} else None
+
+    # DDP mode
+    if cuda and RANK != -1:
+        model = smart_DDP(model)
+
+    # Train
+    t0 = time.time()
+    criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing)  # loss function
+    best_fitness = 0.0
+    scaler = amp.GradScaler(enabled=cuda)
+    val = test_dir.stem  # 'val' or 'test'
+    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} test\n'
+                f'Using {nw * WORLD_SIZE} dataloader workers\n'
+                f"Logging results to {colorstr('bold', save_dir)}\n"
+                f'Starting {opt.model} training on {data} dataset with {nc} classes for {epochs} epochs...\n\n'
+                f"{'Epoch':>10}{'GPU_mem':>10}{'train_loss':>12}{f'{val}_loss':>12}{'top1_acc':>12}{'top5_acc':>12}")
+    for epoch in range(epochs):  # loop over the dataset multiple times
+        tloss, vloss, fitness = 0.0, 0.0, 0.0  # train loss, val loss, fitness
+        model.train()
+        if RANK != -1:
+            trainloader.sampler.set_epoch(epoch)
+        pbar = enumerate(trainloader)
+        if RANK in {-1, 0}:
+            pbar = tqdm(enumerate(trainloader), total=len(trainloader), bar_format=TQDM_BAR_FORMAT)
+        for i, (images, labels) in pbar:  # progress bar
+            images, labels = images.to(device, non_blocking=True), labels.to(device)
+
+            # Forward
+            with amp.autocast(enabled=cuda):  # stability issues when enabled
+                loss = criterion(model(images), labels)
+
+            # Backward
+            scaler.scale(loss).backward()
+
+            # Optimize
+            scaler.unscale_(optimizer)  # unscale gradients
+            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
+            scaler.step(optimizer)
+            scaler.update()
+            optimizer.zero_grad()
+            if ema:
+                ema.update(model)
+
+            if RANK in {-1, 0}:
+                # Print
+                tloss = (tloss * i + loss.item()) / (i + 1)  # update mean losses
+                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
+                pbar.desc = f"{f'{epoch + 1}/{epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
+
+                # Test
+                if i == len(pbar) - 1:  # last batch
+                    top1, top5, vloss = validate.run(model=ema.ema,
+                                                     dataloader=testloader,
+                                                     criterion=criterion,
+                                                     pbar=pbar)  # test accuracy, loss
+                    fitness = top1  # define fitness as top1 accuracy
+
+        # Scheduler
+        scheduler.step()
+
+        # Log metrics
+        if RANK in {-1, 0}:
+            # Best fitness
+            if fitness > best_fitness:
+                best_fitness = fitness
+
+            # Log
+            metrics = {
+                "train/loss": tloss,
+                f"{val}/loss": vloss,
+                "metrics/accuracy_top1": top1,
+                "metrics/accuracy_top5": top5,
+                "lr/0": optimizer.param_groups[0]['lr']}  # learning rate
+            logger.log_metrics(metrics, epoch)
+
+            # Save model
+            final_epoch = epoch + 1 == epochs
+            if (not opt.nosave) or final_epoch:
+                ckpt = {
+                    'epoch': epoch,
+                    'best_fitness': best_fitness,
+                    'model': deepcopy(ema.ema).half(),  # deepcopy(de_parallel(model)).half(),
+                    'ema': None,  # deepcopy(ema.ema).half(),
+                    'updates': ema.updates,
+                    'optimizer': None,  # optimizer.state_dict(),
+                    'opt': vars(opt),
+                    'git': GIT_INFO,  # {remote, branch, commit} if a git repo
+                    'date': datetime.now().isoformat()}
+
+                # Save last, best and delete
+                torch.save(ckpt, last)
+                if best_fitness == fitness:
+                    torch.save(ckpt, best)
+                del ckpt
+
+    # Train complete
+    if RANK in {-1, 0} and final_epoch:
+        LOGGER.info(f'\nTraining complete ({(time.time() - t0) / 3600:.3f} hours)'
+                    f"\nResults saved to {colorstr('bold', save_dir)}"
+                    f"\nPredict:     python classify/predict.py --weights {best} --source im.jpg"
+                    f"\nValidate:    python classify/val.py --weights {best} --data {data_dir}"
+                    f"\nExport:      python export.py --weights {best} --include onnx"
+                    f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{best}')"
+                    f"\nVisualize:   https://netron.app\n")
+
+        # Plot examples
+        images, labels = (x[:25] for x in next(iter(testloader)))  # first 25 images and labels
+        pred = torch.max(ema.ema(images.to(device)), 1)[1]
+        file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')
+
+        # Log results
+        meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()}
+        logger.log_images(file, name='Test Examples (true-predicted)', epoch=epoch)
+        logger.log_model(best, epochs, metadata=meta)
+
+
+def parse_opt(known=False):
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', type=str, default='yolov5s-cls.pt', help='initial weights path')
+    parser.add_argument('--data', type=str, default='imagenette160', help='cifar10, cifar100, mnist, imagenet, ...')
+    parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
+    parser.add_argument('--batch-size', type=int, default=64, help='total batch size for all GPUs')
+    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)')
+    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
+    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
+    parser.add_argument('--project', default=ROOT / 'runs/train-cls', help='save to project/name')
+    parser.add_argument('--name', default='exp', help='save to project/name')
+    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument('--pretrained', nargs='?', const=True, default=True, help='start from i.e. --pretrained False')
+    parser.add_argument('--optimizer', choices=['SGD', 'Adam', 'AdamW', 'RMSProp'], default='Adam', help='optimizer')
+    parser.add_argument('--lr0', type=float, default=0.001, help='initial learning rate')
+    parser.add_argument('--decay', type=float, default=5e-5, help='weight decay')
+    parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing epsilon')
+    parser.add_argument('--cutoff', type=int, default=None, help='Model layer cutoff index for Classify() head')
+    parser.add_argument('--dropout', type=float, default=None, help='Dropout (fraction)')
+    parser.add_argument('--verbose', action='store_true', help='Verbose mode')
+    parser.add_argument('--seed', type=int, default=0, help='Global training seed')
+    parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
+    return parser.parse_known_args()[0] if known else parser.parse_args()
+
+
+def main(opt):
+    # Checks
+    if RANK in {-1, 0}:
+        print_args(vars(opt))
+        check_git_status()
+        check_requirements()
+
+    # DDP mode
+    device = select_device(opt.device, batch_size=opt.batch_size)
+    if LOCAL_RANK != -1:
+        assert opt.batch_size != -1, 'AutoBatch is coming soon for classification, please pass a valid --batch-size'
+        assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
+        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
+        torch.cuda.set_device(LOCAL_RANK)
+        device = torch.device('cuda', LOCAL_RANK)
+        dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
+
+    # Parameters
+    opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)  # increment run
+
+    # Train
+    train(opt, device)
+
+
+def run(**kwargs):
+    # Usage: from yolov5 import classify; classify.train.run(data=mnist, imgsz=320, model='yolov5m')
+    opt = parse_opt(True)
+    for k, v in kwargs.items():
+        setattr(opt, k, v)
+    main(opt)
+    return opt
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)
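The module also exposes `run(**kwargs)` for programmatic use, mirroring the usage comment in the code. A minimal sketch, assuming single-GPU or CPU training; the kwargs map one-to-one onto the `parse_opt()` arguments above:

```python
from classify import train

# Fine-tune a classifier for a few epochs; cifar10 is auto-downloaded
# if missing, per the dataset-download block in train().
opt = train.run(model='yolov5s-cls.pt', data='cifar10', epochs=5, imgsz=224)
```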
yolov9/classify/val.py ADDED
@@ -0,0 +1,170 @@
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
+"""
+Validate a trained YOLOv5 classification model on a classification dataset
+
+Usage:
+    $ bash data/scripts/get_imagenet.sh --val  # download ImageNet val split (6.3G, 50000 images)
+    $ python classify/val.py --weights yolov5m-cls.pt --data ../datasets/imagenet --img 224  # validate ImageNet
+
+Usage - formats:
+    $ python classify/val.py --weights yolov5s-cls.pt             # PyTorch
+                                       yolov5s-cls.torchscript    # TorchScript
+                                       yolov5s-cls.onnx           # ONNX Runtime or OpenCV DNN with --dnn
+                                       yolov5s-cls_openvino_model # OpenVINO
+                                       yolov5s-cls.engine         # TensorRT
+                                       yolov5s-cls.mlmodel        # CoreML (macOS-only)
+                                       yolov5s-cls_saved_model    # TensorFlow SavedModel
+                                       yolov5s-cls.pb             # TensorFlow GraphDef
+                                       yolov5s-cls.tflite         # TensorFlow Lite
+                                       yolov5s-cls_edgetpu.tflite # TensorFlow Edge TPU
+                                       yolov5s-cls_paddle_model   # PaddlePaddle
+"""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[1]  # YOLOv5 root directory
+if str(ROOT) not in sys.path:
+    sys.path.append(str(ROOT))  # add ROOT to PATH
+ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+from models.common import DetectMultiBackend
+from utils.dataloaders import create_classification_dataloader
+from utils.general import (LOGGER, TQDM_BAR_FORMAT, Profile, check_img_size, check_requirements, colorstr,
+                           increment_path, print_args)
+from utils.torch_utils import select_device, smart_inference_mode
+
+
+@smart_inference_mode()
+def run(
+        data=ROOT / '../datasets/mnist',  # dataset dir
+        weights=ROOT / 'yolov5s-cls.pt',  # model.pt path(s)
+        batch_size=128,  # batch size
+        imgsz=224,  # inference size (pixels)
+        device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+        workers=8,  # max dataloader workers (per RANK in DDP mode)
+        verbose=False,  # verbose output
+        project=ROOT / 'runs/val-cls',  # save to project/name
+        name='exp',  # save to project/name
+        exist_ok=False,  # existing project/name ok, do not increment
+        half=False,  # use FP16 half-precision inference
+        dnn=False,  # use OpenCV DNN for ONNX inference
+        model=None,
+        dataloader=None,
+        criterion=None,
+        pbar=None,
+):
+    # Initialize/load model and set device
+    training = model is not None
+    if training:  # called by train.py
+        device, pt, jit, engine = next(model.parameters()).device, True, False, False  # get model device, PyTorch model
+        half &= device.type != 'cpu'  # half precision only supported on CUDA
+        model.half() if half else model.float()
+    else:  # called directly
+        device = select_device(device, batch_size=batch_size)
+
+        # Directories
+        save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
+        save_dir.mkdir(parents=True, exist_ok=True)  # make dir
+
+        # Load model
+        model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
+        stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
+        imgsz = check_img_size(imgsz, s=stride)  # check image size
+        half = model.fp16  # FP16 supported on limited backends with CUDA
+        if engine:
+            batch_size = model.batch_size
+        else:
+            device = model.device
+            if not (pt or jit):
+                batch_size = 1  # export.py models default to batch-size 1
+                LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+
+        # Dataloader
+        data = Path(data)
+        test_dir = data / 'test' if (data / 'test').exists() else data / 'val'  # data/test or data/val
+        dataloader = create_classification_dataloader(path=test_dir,
+                                                      imgsz=imgsz,
+                                                      batch_size=batch_size,
+                                                      augment=False,
+                                                      rank=-1,
+                                                      workers=workers)
+
+    model.eval()
+    pred, targets, loss, dt = [], [], 0, (Profile(), Profile(), Profile())
+    n = len(dataloader)  # number of batches
+    action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
+    desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
+    bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0)
+    with torch.cuda.amp.autocast(enabled=device.type != 'cpu'):
+        for images, labels in bar:
+            with dt[0]:
+                images, labels = images.to(device, non_blocking=True), labels.to(device)
+
+            with dt[1]:
+                y = model(images)
+
+            with dt[2]:
+                pred.append(y.argsort(1, descending=True)[:, :5])
+                targets.append(labels)
+                if criterion:
+                    loss += criterion(y, labels)
+
+    loss /= n
+    pred, targets = torch.cat(pred), torch.cat(targets)
+    correct = (targets[:, None] == pred).float()
+    acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1)  # (top1, top5) accuracy
+    top1, top5 = acc.mean(0).tolist()
+
+    if pbar:
+        pbar.desc = f"{pbar.desc[:-36]}{loss:>12.3g}{top1:>12.3g}{top5:>12.3g}"
+    if verbose:  # all classes
+        LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
+        LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
+        for i, c in model.names.items():
+            aci = acc[targets == i]
+            top1i, top5i = aci.mean(0).tolist()
+            LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
+
+        # Print results
+        t = tuple(x.t / len(dataloader.dataset.samples) * 1E3 for x in dt)  # speeds per image
+        shape = (1, 3, imgsz, imgsz)
+        LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
+        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
+
+    return top1, top5, loss
+
+
+def parse_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--data', type=str, default=ROOT / '../datasets/mnist', help='dataset path')
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model.pt path(s)')
+    parser.add_argument('--batch-size', type=int, default=128, help='batch size')
+    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='inference size (pixels)')
+    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
+    parser.add_argument('--verbose', nargs='?', const=True, default=True, help='verbose output')
+    parser.add_argument('--project', default=ROOT / 'runs/val-cls', help='save to project/name')
+    parser.add_argument('--name', default='exp', help='save to project/name')
+    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+    parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+    opt = parser.parse_args()
+    print_args(vars(opt))
+    return opt
+
+
+def main(opt):
+    check_requirements(exclude=('tensorboard', 'thop'))
+    run(**vars(opt))
+
+
+if __name__ == "__main__":
+    opt = parse_opt()
+    main(opt)
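As with training, validation can be invoked from Python; `run()` returns the top-1 accuracy, top-5 accuracy, and loss it logs. A minimal sketch, assuming an ImageNet-style directory containing a `val/` or `test/` split of class-named folders:

```python
from classify.val import run

# Validate a classifier checkpoint on a local dataset directory.
top1, top5, loss = run(data='../datasets/imagenet', weights='yolov5s-cls.pt', imgsz=224)
```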
yolov9/data/coco.yaml ADDED
@@ -0,0 +1,106 @@
+path: ../datasets/coco  # dataset root dir
+train: train2017.txt  # train images (relative to 'path') 118287 images
+val: val2017.txt  # val images (relative to 'path') 5000 images
+test: test-dev2017.txt  # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
+
+# Classes
+names:
+  0: person
+  1: bicycle
+  2: car
+  3: motorcycle
+  4: airplane
+  5: bus
+  6: train
+  7: truck
+  8: boat
+  9: traffic light
+  10: fire hydrant
+  11: stop sign
+  12: parking meter
+  13: bench
+  14: bird
+  15: cat
+  16: dog
+  17: horse
+  18: sheep
+  19: cow
+  20: elephant
+  21: bear
+  22: zebra
+  23: giraffe
+  24: backpack
+  25: umbrella
+  26: handbag
+  27: tie
+  28: suitcase
+  29: frisbee
+  30: skis
+  31: snowboard
+  32: sports ball
+  33: kite
+  34: baseball bat
+  35: baseball glove
+  36: skateboard
+  37: surfboard
+  38: tennis racket
+  39: bottle
+  40: wine glass
+  41: cup
+  42: fork
+  43: knife
+  44: spoon
+  45: bowl
+  46: banana
+  47: apple
+  48: sandwich
+  49: orange
+  50: broccoli
+  51: carrot
+  52: hot dog
+  53: pizza
+  54: donut
+  55: cake
+  56: chair
+  57: couch
+  58: potted plant
+  59: bed
+  60: dining table
+  61: toilet
+  62: tv
+  63: laptop
+  64: mouse
+  65: remote
+  66: keyboard
+  67: cell phone
+  68: microwave
+  69: oven
+  70: toaster
+  71: sink
+  72: refrigerator
+  73: book
+  74: clock
+  75: vase
+  76: scissors
+  77: teddy bear
+  78: hair drier
+  79: toothbrush
+
+
+# Download script/URL (optional)
+download: |
+  from utils.general import download, Path
+
+
+  # Download labels
+  #segments = True  # segment or box labels
+  #dir = Path(yaml['path'])  # dataset root dir
+  #url = 'https://github.com/WongKinYiu/yolov7/releases/download/v0.1/'
+  #urls = [url + ('coco2017labels-segments.zip' if segments else 'coco2017labels.zip')]  # labels
+  #download(urls, dir=dir.parent)
+
+  # Download data
+  #urls = ['http://images.cocodataset.org/zips/train2017.zip',  # 19G, 118k images
+  #        'http://images.cocodataset.org/zips/val2017.zip',  # 1G, 5k images
+  #        'http://images.cocodataset.org/zips/test2017.zip']  # 7G, 41k images (optional)
+  #download(urls, dir=dir / 'images', threads=3)
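The `download:` block above ships fully commented out, so nothing is fetched automatically. The same logic can be run by hand; a minimal sketch using the repo's own `download` helper, assuming it is executed from the repo root (uncommenting the YAML block achieves the same result):

```python
from pathlib import Path

from utils.general import download  # same helper the YAML block imports

dir = Path('../datasets/coco')  # dataset root dir, matching 'path' above

# Labels (segment variant, as recommended in the README)
download(['https://github.com/WongKinYiu/yolov7/releases/download/v0.1/coco2017labels-segments.zip'],
         dir=dir.parent)

# Images (test2017.zip is optional and omitted here)
download(['http://images.cocodataset.org/zips/train2017.zip',  # 19G, 118k images
          'http://images.cocodataset.org/zips/val2017.zip'],  # 1G, 5k images
         dir=dir / 'images', threads=3)
```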
yolov9/data/hyps/hyp.scratch-high.yaml ADDED
@@ -0,0 +1,30 @@
+lr0: 0.01  # initial learning rate (SGD=1E-2, Adam=1E-3)
+lrf: 0.01  # final OneCycleLR learning rate (lr0 * lrf)
+momentum: 0.937  # SGD momentum/Adam beta1
+weight_decay: 0.0005  # optimizer weight decay 5e-4
+warmup_epochs: 3.0  # warmup epochs (fractions ok)
+warmup_momentum: 0.8  # warmup initial momentum
+warmup_bias_lr: 0.1  # warmup initial bias lr
+box: 7.5  # box loss gain
+cls: 0.5  # cls loss gain
+cls_pw: 1.0  # cls BCELoss positive_weight
+obj: 0.7  # obj loss gain (scale with pixels)
+obj_pw: 1.0  # obj BCELoss positive_weight
+dfl: 1.5  # dfl loss gain
+iou_t: 0.20  # IoU training threshold
+anchor_t: 5.0  # anchor-multiple threshold
+# anchors: 3  # anchors per output layer (0 to ignore)
+fl_gamma: 0.0  # focal loss gamma (efficientDet default gamma=1.5)
+hsv_h: 0.015  # image HSV-Hue augmentation (fraction)
+hsv_s: 0.7  # image HSV-Saturation augmentation (fraction)
+hsv_v: 0.4  # image HSV-Value augmentation (fraction)
+degrees: 0.0  # image rotation (+/- deg)
+translate: 0.1  # image translation (+/- fraction)
+scale: 0.9  # image scale (+/- gain)
+shear: 0.0  # image shear (+/- deg)
+perspective: 0.0  # image perspective (+/- fraction), range 0-0.001
+flipud: 0.0  # image flip up-down (probability)
+fliplr: 0.5  # image flip left-right (probability)
+mosaic: 1.0  # image mosaic (probability)
+mixup: 0.15  # image mixup (probability)
+copy_paste: 0.3  # segment copy-paste (probability)
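These hyperparameters are consumed by the training scripts via the `--hyp` flag (see the README training commands), which read the file into a plain dict. A minimal sketch of that loading step, assuming PyYAML:

```python
import yaml

# Load the hyperparameter dict the same way the train scripts do with --hyp.
with open('data/hyps/hyp.scratch-high.yaml', errors='ignore') as f:
    hyp = yaml.safe_load(f)

print(hyp['lr0'], hyp['box'], hyp['dfl'])  # 0.01 7.5 1.5
```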
yolov9/detect.py ADDED
@@ -0,0 +1,232 @@
+ import argparse
+ import os
+ import platform
+ import sys
+ from pathlib import Path
+
+ import torch
+
+ FILE = Path(__file__).resolve()
+ ROOT = FILE.parents[0]  # YOLO root directory
+ if str(ROOT) not in sys.path:
+     sys.path.append(str(ROOT))  # add ROOT to PATH
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+ from models.common import DetectMultiBackend
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
+                            increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
+ from utils.plots import Annotator, colors, save_one_box
+ from utils.torch_utils import select_device, smart_inference_mode
+
+
+ @smart_inference_mode()
+ def run(
+         weights=ROOT / 'yolo.pt',  # model path or triton URL
+         source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
+         data=ROOT / 'data/coco.yaml',  # dataset.yaml path
+         imgsz=(640, 640),  # inference size (height, width)
+         conf_thres=0.25,  # confidence threshold
+         iou_thres=0.45,  # NMS IOU threshold
+         max_det=1000,  # maximum detections per image
+         device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+         view_img=False,  # show results
+         save_txt=False,  # save results to *.txt
+         save_conf=False,  # save confidences in --save-txt labels
+         save_crop=False,  # save cropped prediction boxes
+         nosave=False,  # do not save images/videos
+         classes=None,  # filter by class: --class 0, or --class 0 2 3
+         agnostic_nms=False,  # class-agnostic NMS
+         augment=False,  # augmented inference
+         visualize=False,  # visualize features
+         update=False,  # update all models
+         project=ROOT / 'runs/detect',  # save results to project/name
+         name='exp',  # save results to project/name
+         exist_ok=False,  # existing project/name ok, do not increment
+         line_thickness=3,  # bounding box thickness (pixels)
+         hide_labels=False,  # hide labels
+         hide_conf=False,  # hide confidences
+         half=False,  # use FP16 half-precision inference
+         dnn=False,  # use OpenCV DNN for ONNX inference
+         vid_stride=1,  # video frame-rate stride
+ ):
+     source = str(source)
+     save_img = not nosave and not source.endswith('.txt')  # save inference images
+     is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
+     is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
+     webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
+     screenshot = source.lower().startswith('screen')
+     if is_url and is_file:
+         source = check_file(source)  # download
+
+     # Directories
+     save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
+     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+
+     # Load model
+     device = select_device(device)
+     model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
+     stride, names, pt = model.stride, model.names, model.pt
+     imgsz = check_img_size(imgsz, s=stride)  # check image size
+
+     # Dataloader
+     bs = 1  # batch_size
+     if webcam:
+         view_img = check_imshow(warn=True)
+         dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
+         bs = len(dataset)
+     elif screenshot:
+         dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
+     else:
+         dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
+     vid_path, vid_writer = [None] * bs, [None] * bs
+
+     # Run inference
+     model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz))  # warmup
+     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
+     for path, im, im0s, vid_cap, s in dataset:
+         with dt[0]:
+             im = torch.from_numpy(im).to(model.device)
+             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
+             im /= 255  # 0 - 255 to 0.0 - 1.0
+             if len(im.shape) == 3:
+                 im = im[None]  # expand for batch dim
+
+         # Inference
+         with dt[1]:
+             visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
+             pred = model(im, augment=augment, visualize=visualize)
+
+         # NMS
+         with dt[2]:
+             pred = pred[0][1] if isinstance(pred[0], list) else pred[0]
+             pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
+
+         # Second-stage classifier (optional)
+         # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
+
+         # Process predictions
+         for i, det in enumerate(pred):  # per image
+             seen += 1
+             if webcam:  # batch_size >= 1
+                 p, im0, frame = path[i], im0s[i].copy(), dataset.count
+                 s += f'{i}: '
+             else:
+                 p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
+
+             p = Path(p)  # to Path
+             save_path = str(save_dir / p.name)  # im.jpg
+             txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
+             s += '%gx%g ' % im.shape[2:]  # print string
+             gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
+             imc = im0.copy() if save_crop else im0  # for save_crop
+             annotator = Annotator(im0, line_width=line_thickness, example=str(names))
+             if len(det):
+                 # Rescale boxes from img_size to im0 size
+                 det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
+
+                 # Print results
+                 for c in det[:, 5].unique():
+                     n = (det[:, 5] == c).sum()  # detections per class
+                     s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
+
+                 # Write results
+                 for *xyxy, conf, cls in reversed(det):
+                     if save_txt:  # Write to file
+                         xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
+                         line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
+                         with open(f'{txt_path}.txt', 'a') as f:
+                             f.write(('%g ' * len(line)).rstrip() % line + '\n')
+
+                     if save_img or save_crop or view_img:  # Add bbox to image
+                         c = int(cls)  # integer class
+                         label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
+                         annotator.box_label(xyxy, label, color=colors(c, True))
+                     if save_crop:
+                         save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
+
+             # Stream results
+             im0 = annotator.result()
+             if view_img:
+                 if platform.system() == 'Linux' and p not in windows:
+                     windows.append(p)
+                     cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
+                     cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
+                 cv2.imshow(str(p), im0)
+                 cv2.waitKey(1)  # 1 millisecond
+
+             # Save results (image with detections)
+             if save_img:
+                 if dataset.mode == 'image':
+                     cv2.imwrite(save_path, im0)
+                 else:  # 'video' or 'stream'
+                     if vid_path[i] != save_path:  # new video
+                         vid_path[i] = save_path
+                         if isinstance(vid_writer[i], cv2.VideoWriter):
+                             vid_writer[i].release()  # release previous video writer
+                         if vid_cap:  # video
+                             fps = vid_cap.get(cv2.CAP_PROP_FPS)
+                             w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                             h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                         else:  # stream
+                             fps, w, h = 30, im0.shape[1], im0.shape[0]
+                         save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
+                         vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+                     vid_writer[i].write(im0)
+
+         # Print time (inference-only)
+         LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
+
+     # Print results
+     t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
+     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
+     if save_txt or save_img:
+         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
+     if update:
+         strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)
+
+
+ def parse_opt():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model path or triton URL')
+     parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
+     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
+     parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
+     parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
+     parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
+     parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
+     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--view-img', action='store_true', help='show results')
+     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
+     parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+     parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
+     parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
+     parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
+     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
+     parser.add_argument('--augment', action='store_true', help='augmented inference')
+     parser.add_argument('--visualize', action='store_true', help='visualize features')
+     parser.add_argument('--update', action='store_true', help='update all models')
+     parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
+     parser.add_argument('--name', default='exp', help='save results to project/name')
+     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+     parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
+     parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
+     parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
+     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+     parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+     parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
+     opt = parser.parse_args()
+     opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
+     print_args(vars(opt))
+     return opt
+
+
+ def main(opt):
+     check_requirements(exclude=('tensorboard', 'thop'))
+     run(**vars(opt))
+
+
+ if __name__ == "__main__":
+     opt = parse_opt()
+     main(opt)
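Typical usage of this script, both from the shell and programmatically (the weights filename below is an assumed example, and the snippet assumes it runs from the yolov9/ directory):

# Sketch, not part of the commit. From the shell:
#   python detect.py --weights best.pt --source data/images --conf-thres 0.25 --save-txt
# Or from Python, reusing run() directly:
from detect import run

run(weights='best.pt',  # assumed checkpoint name
    source='data/images',  # folder of images
    conf_thres=0.25,
    save_txt=True)  # writes runs/detect/exp/labels/*.txt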
yolov9/export.py ADDED
@@ -0,0 +1,606 @@
+ import argparse
+ import contextlib
+ import json
+ import os
+ import platform
+ import re
+ import subprocess
+ import sys
+ import time
+ import warnings
+ from pathlib import Path
+
+ import pandas as pd
+ import torch
+ from torch.utils.mobile_optimizer import optimize_for_mobile
+
+ FILE = Path(__file__).resolve()
+ ROOT = FILE.parents[0]  # YOLO root directory
+ if str(ROOT) not in sys.path:
+     sys.path.append(str(ROOT))  # add ROOT to PATH
+ if platform.system() != 'Windows':
+     ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+ from models.experimental import attempt_load
+ from models.yolo import ClassificationModel, Detect, DetectionModel, SegmentationModel
+ from utils.dataloaders import LoadImages
+ from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
+                            check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
+ from utils.torch_utils import select_device, smart_inference_mode
+
+ MACOS = platform.system() == 'Darwin'  # macOS environment
+
+
+ def export_formats():
+     # YOLO export formats
+     x = [
+         ['PyTorch', '-', '.pt', True, True],
+         ['TorchScript', 'torchscript', '.torchscript', True, True],
+         ['ONNX', 'onnx', '.onnx', True, True],
+         ['OpenVINO', 'openvino', '_openvino_model', True, False],
+         ['TensorRT', 'engine', '.engine', False, True],
+         ['CoreML', 'coreml', '.mlmodel', True, False],
+         ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
+         ['TensorFlow GraphDef', 'pb', '.pb', True, True],
+         ['TensorFlow Lite', 'tflite', '.tflite', True, False],
+         ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
+         ['TensorFlow.js', 'tfjs', '_web_model', False, False],
+         ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
+     return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
+
+
+ def try_export(inner_func):
+     # YOLO export decorator, i.e. @try_export
+     inner_args = get_default_args(inner_func)
+
+     def outer_func(*args, **kwargs):
+         prefix = inner_args['prefix']
+         try:
+             with Profile() as dt:
+                 f, model = inner_func(*args, **kwargs)
+             LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
+             return f, model
+         except Exception as e:
+             LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
+             return None, None
+
+     return outer_func
+
+
+ @try_export
+ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
+     # YOLO TorchScript model export
+     LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
+     f = file.with_suffix('.torchscript')
+
+     ts = torch.jit.trace(model, im, strict=False)
+     d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
+     extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
+     if optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
+         optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
+     else:
+         ts.save(str(f), _extra_files=extra_files)
+     return f, None
+
+
+ @try_export
+ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
+     # YOLO ONNX export
+     check_requirements('onnx')
+     import onnx
+
+     LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
+     f = file.with_suffix('.onnx')
+
+     output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
+     if dynamic:
+         dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # shape(1,3,640,640)
+         if isinstance(model, SegmentationModel):
+             dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
+             dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'}  # shape(1,32,160,160)
+         elif isinstance(model, DetectionModel):
+             dynamic['output0'] = {0: 'batch', 1: 'anchors'}  # shape(1,25200,85)
+
+     torch.onnx.export(
+         model.cpu() if dynamic else model,  # --dynamic only compatible with cpu
+         im.cpu() if dynamic else im,
+         f,
+         verbose=False,
+         opset_version=opset,
+         do_constant_folding=True,
+         input_names=['images'],
+         output_names=output_names,
+         dynamic_axes=dynamic or None)
+
+     # Checks
+     model_onnx = onnx.load(f)  # load onnx model
+     onnx.checker.check_model(model_onnx)  # check onnx model
+
+     # Metadata
+     d = {'stride': int(max(model.stride)), 'names': model.names}
+     for k, v in d.items():
+         meta = model_onnx.metadata_props.add()
+         meta.key, meta.value = k, str(v)
+     onnx.save(model_onnx, f)
+
+     # Simplify
+     if simplify:
+         try:
+             cuda = torch.cuda.is_available()
+             check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
+             import onnxsim
+
+             LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
+             model_onnx, check = onnxsim.simplify(model_onnx)
+             assert check, 'assert check failed'
+             onnx.save(model_onnx, f)
+         except Exception as e:
+             LOGGER.info(f'{prefix} simplifier failure: {e}')
+     return f, model_onnx
+
+
+ @try_export
+ def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
+     # YOLO OpenVINO export
+     check_requirements('openvino-dev')  # requires openvino-dev: https://pypi.org/project/openvino-dev/
+     import openvino.inference_engine as ie
+
+     LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
+     f = str(file).replace('.pt', f'_openvino_model{os.sep}')
+
+     cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
+     subprocess.run(cmd.split(), check=True, env=os.environ)  # export
+     yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
+     return f, None
+
+
+ @try_export
+ def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
+     # YOLO Paddle export
+     check_requirements(('paddlepaddle', 'x2paddle'))
+     import x2paddle
+     from x2paddle.convert import pytorch2paddle
+
+     LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
+     f = str(file).replace('.pt', f'_paddle_model{os.sep}')
+
+     pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im])  # export
+     yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata)  # add metadata.yaml
+     return f, None
+
+
+ @try_export
+ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
+     # YOLO CoreML export
+     check_requirements('coremltools')
+     import coremltools as ct
+
+     LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
+     f = file.with_suffix('.mlmodel')
+
+     ts = torch.jit.trace(model, im, strict=False)  # TorchScript model
+     ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
+     bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
+     if bits < 32:
+         if MACOS:  # quantization only supported on macOS
+             with warnings.catch_warnings():
+                 warnings.filterwarnings("ignore", category=DeprecationWarning)  # suppress numpy==1.20 float warning
+                 ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
+         else:
+             print(f'{prefix} quantization only supported on macOS, skipping...')
+     ct_model.save(f)
+     return f, ct_model
+
+
+ @try_export
+ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
+     # YOLO TensorRT export https://developer.nvidia.com/tensorrt
+     assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
+     try:
+         import tensorrt as trt
+     except Exception:
+         if platform.system() == 'Linux':
+             check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
+         import tensorrt as trt
+
+     if trt.__version__[0] == '7':  # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
+         grid = model.model[-1].anchor_grid
+         model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
+         export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
+         model.model[-1].anchor_grid = grid
+     else:  # TensorRT >= 8
+         check_version(trt.__version__, '8.0.0', hard=True)  # require tensorrt>=8.0.0
+         export_onnx(model, im, file, 12, dynamic, simplify)  # opset 12
+     onnx = file.with_suffix('.onnx')
+
+     LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
+     assert onnx.exists(), f'failed to export ONNX file: {onnx}'
+     f = file.with_suffix('.engine')  # TensorRT engine file
+     logger = trt.Logger(trt.Logger.INFO)
+     if verbose:
+         logger.min_severity = trt.Logger.Severity.VERBOSE
+
+     builder = trt.Builder(logger)
+     config = builder.create_builder_config()
+     config.max_workspace_size = workspace * 1 << 30
+     # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)  # fix TRT 8.4 deprecation notice
+
+     flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
+     network = builder.create_network(flag)
+     parser = trt.OnnxParser(network, logger)
+     if not parser.parse_from_file(str(onnx)):
+         raise RuntimeError(f'failed to load ONNX file: {onnx}')
+
+     inputs = [network.get_input(i) for i in range(network.num_inputs)]
+     outputs = [network.get_output(i) for i in range(network.num_outputs)]
+     for inp in inputs:
+         LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
+     for out in outputs:
+         LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
+
+     if dynamic:
+         if im.shape[0] <= 1:
+             LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
+         profile = builder.create_optimization_profile()
+         for inp in inputs:
+             profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
+         config.add_optimization_profile(profile)
+
+     LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
+     if builder.platform_has_fast_fp16 and half:
+         config.set_flag(trt.BuilderFlag.FP16)
+     with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
+         t.write(engine.serialize())
+     return f, None
+
+
+ @try_export
+ def export_saved_model(model,
+                        im,
+                        file,
+                        dynamic,
+                        tf_nms=False,
+                        agnostic_nms=False,
+                        topk_per_class=100,
+                        topk_all=100,
+                        iou_thres=0.45,
+                        conf_thres=0.25,
+                        keras=False,
+                        prefix=colorstr('TensorFlow SavedModel:')):
+     # YOLO TensorFlow SavedModel export
+     try:
+         import tensorflow as tf
+     except Exception:
+         check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
+         import tensorflow as tf
+     from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+     from models.tf import TFModel
+
+     LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+     f = str(file).replace('.pt', '_saved_model')
+     batch_size, ch, *imgsz = list(im.shape)  # BCHW
+
+     tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
+     im = tf.zeros((batch_size, *imgsz, ch))  # BHWC order for TensorFlow
+     _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
+     inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
+     outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
+     keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
+     keras_model.trainable = False
+     keras_model.summary()
+     if keras:
+         keras_model.save(f, save_format='tf')
+     else:
+         spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
+         m = tf.function(lambda x: keras_model(x))  # full model
+         m = m.get_concrete_function(spec)
+         frozen_func = convert_variables_to_constants_v2(m)
+         tfm = tf.Module()
+         tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
+         tfm.__call__(im)
+         tf.saved_model.save(tfm,
+                             f,
+                             options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
+                                 tf.__version__, '2.6') else tf.saved_model.SaveOptions())
+     return f, keras_model
+
+
+ @try_export
+ def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
+     # YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
+     import tensorflow as tf
+     from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
+
+     LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+     f = file.with_suffix('.pb')
+
+     m = tf.function(lambda x: keras_model(x))  # full model
+     m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
+     frozen_func = convert_variables_to_constants_v2(m)
+     frozen_func.graph.as_graph_def()
+     tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
+     return f, None
+
+
+ @try_export
+ def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
+     # YOLOv5 TensorFlow Lite export
+     import tensorflow as tf
+
+     LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
+     batch_size, ch, *imgsz = list(im.shape)  # BCHW
+     f = str(file).replace('.pt', '-fp16.tflite')
+
+     converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+     converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
+     converter.target_spec.supported_types = [tf.float16]
+     converter.optimizations = [tf.lite.Optimize.DEFAULT]
+     if int8:
+         from models.tf import representative_dataset_gen
+         dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
+         converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
+         converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+         converter.target_spec.supported_types = []
+         converter.inference_input_type = tf.uint8  # or tf.int8
+         converter.inference_output_type = tf.uint8  # or tf.int8
+         converter.experimental_new_quantizer = True
+         f = str(file).replace('.pt', '-int8.tflite')
+     if nms or agnostic_nms:
+         converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
+
+     tflite_model = converter.convert()
+     open(f, "wb").write(tflite_model)
+     return f, None
+
+
+ @try_export
+ def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
+     # YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
+     cmd = 'edgetpu_compiler --version'
+     help_url = 'https://coral.ai/docs/edgetpu/compiler/'
+     assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
+     if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
+         LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
+         sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0  # sudo installed on system
+         for c in (
+                 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
+                 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
+                 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
+             subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
+     ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
+
+     LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
+     f = str(file).replace('.pt', '-int8_edgetpu.tflite')  # Edge TPU model
+     f_tfl = str(file).replace('.pt', '-int8.tflite')  # TFLite model
+
+     cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
+     subprocess.run(cmd.split(), check=True)
+     return f, None
+
+
+ @try_export
+ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
+     # YOLO TensorFlow.js export
+     check_requirements('tensorflowjs')
+     import tensorflowjs as tfjs
+
+     LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
+     f = str(file).replace('.pt', '_web_model')  # js dir
+     f_pb = file.with_suffix('.pb')  # *.pb path
+     f_json = f'{f}/model.json'  # *.json path
+
+     cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
+           f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
+     subprocess.run(cmd.split())
+
+     json = Path(f_json).read_text()
+     with open(f_json, 'w') as j:  # sort JSON Identity_* in ascending order
+         subst = re.sub(
+             r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
+             r'"Identity.?.?": {"name": "Identity.?.?"}, '
+             r'"Identity.?.?": {"name": "Identity.?.?"}, '
+             r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
+             r'"Identity_1": {"name": "Identity_1"}, '
+             r'"Identity_2": {"name": "Identity_2"}, '
+             r'"Identity_3": {"name": "Identity_3"}}}', json)
+         j.write(subst)
+     return f, None
+
+
+ def add_tflite_metadata(file, metadata, num_outputs):
+     # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
+     with contextlib.suppress(ImportError):
+         # check_requirements('tflite_support')
+         from tflite_support import flatbuffers
+         from tflite_support import metadata as _metadata
+         from tflite_support import metadata_schema_py_generated as _metadata_fb
+
+         tmp_file = Path('/tmp/meta.txt')
+         with open(tmp_file, 'w') as meta_f:
+             meta_f.write(str(metadata))
+
+         model_meta = _metadata_fb.ModelMetadataT()
+         label_file = _metadata_fb.AssociatedFileT()
+         label_file.name = tmp_file.name
+         model_meta.associatedFiles = [label_file]
+
+         subgraph = _metadata_fb.SubGraphMetadataT()
+         subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
+         subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
+         model_meta.subgraphMetadata = [subgraph]
+
+         b = flatbuffers.Builder(0)
+         b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
+         metadata_buf = b.Output()
+
+         populator = _metadata.MetadataPopulator.with_model_file(file)
+         populator.load_metadata_buffer(metadata_buf)
+         populator.load_associated_files([str(tmp_file)])
+         populator.populate()
+         tmp_file.unlink()
+
+
+ @smart_inference_mode()
+ def run(
+         data=ROOT / 'data/coco.yaml',  # 'dataset.yaml path'
+         weights=ROOT / 'yolo.pt',  # weights path
+         imgsz=(640, 640),  # image (height, width)
+         batch_size=1,  # batch size
+         device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+         include=('torchscript', 'onnx'),  # include formats
+         half=False,  # FP16 half-precision export
+         inplace=False,  # set YOLO Detect() inplace=True
+         keras=False,  # use Keras
+         optimize=False,  # TorchScript: optimize for mobile
+         int8=False,  # CoreML/TF INT8 quantization
+         dynamic=False,  # ONNX/TF/TensorRT: dynamic axes
+         simplify=False,  # ONNX: simplify model
+         opset=12,  # ONNX: opset version
+         verbose=False,  # TensorRT: verbose log
+         workspace=4,  # TensorRT: workspace size (GB)
+         nms=False,  # TF: add NMS to model
+         agnostic_nms=False,  # TF: add agnostic NMS to model
+         topk_per_class=100,  # TF.js NMS: topk per class to keep
+         topk_all=100,  # TF.js NMS: topk for all classes to keep
+         iou_thres=0.45,  # TF.js NMS: IoU threshold
+         conf_thres=0.25,  # TF.js NMS: confidence threshold
+ ):
+     t = time.time()
+     include = [x.lower() for x in include]  # to lowercase
+     fmts = tuple(export_formats()['Argument'][1:])  # --include arguments
+     flags = [x in include for x in fmts]
+     assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
+     jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags  # export booleans
+     file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights)  # PyTorch weights
+
+     # Load PyTorch model
+     device = select_device(device)
+     if half:
+         assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
+         assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
+     model = attempt_load(weights, device=device, inplace=True, fuse=True)  # load FP32 model
+
+     # Checks
+     imgsz *= 2 if len(imgsz) == 1 else 1  # expand
+     if optimize:
+         assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
+
+     # Input
+     gs = int(max(model.stride))  # grid size (max stride)
+     imgsz = [check_img_size(x, gs) for x in imgsz]  # verify img_size are gs-multiples
+     im = torch.zeros(batch_size, 3, *imgsz).to(device)  # image size(1,3,320,192) BCHW iDetection
+
+     # Update model
+     model.eval()
+     for k, m in model.named_modules():
+         if isinstance(m, Detect):  # note: add this repo's other detect heads (e.g. DDetect) here if exporting them
+             m.inplace = inplace
+             m.dynamic = dynamic
+             m.export = True
+
+     for _ in range(2):
+         y = model(im)  # dry runs
+     if half and not coreml:
+         im, model = im.half(), model.half()  # to FP16
+     shape = tuple((y[0] if isinstance(y, tuple) else y).shape)  # model output shape
+     metadata = {'stride': int(max(model.stride)), 'names': model.names}  # model metadata
+     LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
+
+     # Exports
+     f = [''] * len(fmts)  # exported filenames
+     warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning)  # suppress TracerWarning
+     if jit:  # TorchScript
+         f[0], _ = export_torchscript(model, im, file, optimize)
+     if engine:  # TensorRT required before ONNX
+         f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
+     if onnx or xml:  # OpenVINO requires ONNX
+         f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
+     if xml:  # OpenVINO
+         f[3], _ = export_openvino(file, metadata, half)
+     if coreml:  # CoreML
+         f[4], _ = export_coreml(model, im, file, int8, half)
+     if any((saved_model, pb, tflite, edgetpu, tfjs)):  # TensorFlow formats
+         assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
+         assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
+         f[5], s_model = export_saved_model(model.cpu(),
+                                            im,
+                                            file,
+                                            dynamic,
+                                            tf_nms=nms or agnostic_nms or tfjs,
+                                            agnostic_nms=agnostic_nms or tfjs,
+                                            topk_per_class=topk_per_class,
+                                            topk_all=topk_all,
+                                            iou_thres=iou_thres,
+                                            conf_thres=conf_thres,
+                                            keras=keras)
+         if pb or tfjs:  # pb prerequisite to tfjs
+             f[6], _ = export_pb(s_model, file)
+         if tflite or edgetpu:
+             f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
+             if edgetpu:
+                 f[8], _ = export_edgetpu(file)
+             add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
+         if tfjs:
+             f[9], _ = export_tfjs(file)
+     if paddle:  # PaddlePaddle
+         f[10], _ = export_paddle(model, im, file, metadata)
+
+     # Finish
+     f = [str(x) for x in f if x]  # filter out '' and None
+     if any(f):
+         cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel))  # type
+         dir = Path('segment' if seg else 'classify' if cls else '')
+         h = '--half' if half else ''  # --half FP16 inference arg
+         s = "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \
+             "# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else ''
+         LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
+                     f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
+                     f"\nDetect: python {dir / ('detect.py' if det else 'predict.py')} --weights {f[-1]} {h}"
+                     f"\nValidate: python {dir / 'val.py'} --weights {f[-1]} {h}"
+                     f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}') {s}"
+                     f"\nVisualize: https://netron.app")
+     return f  # return list of exported files/dirs
+
+
+ def parse_opt():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
+     parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model.pt path(s)')
+     parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
+     parser.add_argument('--batch-size', type=int, default=1, help='batch size')
+     parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
+     parser.add_argument('--inplace', action='store_true', help='set YOLO Detect() inplace=True')
+     parser.add_argument('--keras', action='store_true', help='TF: use Keras')
+     parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
+     parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
+     parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
+     parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
+     parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
+     parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
+     parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
+     parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
+     parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
+     parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
+     parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
+     parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
+     parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
+     parser.add_argument(
+         '--include',
+         nargs='+',
+         default=['torchscript'],
+         help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle')
+     opt = parser.parse_args()
+     print_args(vars(opt))
+     return opt
+
+
+ def main(opt):
+     for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
+         run(**vars(opt))
+
+
+ if __name__ == "__main__":
+     opt = parse_opt()
+     main(opt)
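Note the original `isinstance(m, (Detect, V6Detect))` check referenced `V6Detect`, which is never imported and would raise a NameError; it has been narrowed to the imported `Detect`, with a comment flagging that other head classes used by this repo's configs (e.g. `DDetect`) may need to be added. Typical usage of the exporter (weights filename assumed, run from the yolov9/ directory):

# Sketch, not part of the commit. From the shell:
#   python export.py --weights best.pt --include onnx --imgsz 640 640 --simplify
# Or from Python:
from export import run

files = run(weights='best.pt',  # assumed checkpoint name
            include=('onnx',),
            imgsz=(640, 640),
            simplify=True)  # returns the list of exported files, e.g. ['best.onnx']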
yolov9/figure/performance.png ADDED
yolov9/hubconf.py ADDED
@@ -0,0 +1,107 @@
+ import torch
+
+
+ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
+     """Creates or loads a YOLO model
+
+     Arguments:
+         name (str): model name 'yolov3' or path 'path/to/best.pt'
+         pretrained (bool): load pretrained weights into the model
+         channels (int): number of input channels
+         classes (int): number of model classes
+         autoshape (bool): apply YOLO .autoshape() wrapper to model
+         verbose (bool): print all information to screen
+         device (str, torch.device, None): device to use for model parameters
+
+     Returns:
+         YOLO model
+     """
+     from pathlib import Path
+
+     from models.common import AutoShape, DetectMultiBackend
+     from models.experimental import attempt_load
+     from models.yolo import ClassificationModel, DetectionModel, SegmentationModel
+     from utils.downloads import attempt_download
+     from utils.general import LOGGER, check_requirements, intersect_dicts, logging
+     from utils.torch_utils import select_device
+
+     if not verbose:
+         LOGGER.setLevel(logging.WARNING)
+     check_requirements(exclude=('opencv-python', 'tensorboard', 'thop'))
+     name = Path(name)
+     path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name  # checkpoint path
+     try:
+         device = select_device(device)
+         if pretrained and channels == 3 and classes == 80:
+             try:
+                 model = DetectMultiBackend(path, device=device, fuse=autoshape)  # detection model
+                 if autoshape:
+                     if model.pt and isinstance(model.model, ClassificationModel):
+                         LOGGER.warning('WARNING ⚠️ YOLO ClassificationModel is not yet AutoShape compatible. '
+                                        'You must pass torch tensors in BCHW to this model, i.e. shape(1,3,224,224).')
+                     elif model.pt and isinstance(model.model, SegmentationModel):
+                         LOGGER.warning('WARNING ⚠️ YOLO SegmentationModel is not yet AutoShape compatible. '
+                                        'You will not be able to run inference with this model.')
+                     else:
+                         model = AutoShape(model)  # for file/URI/PIL/cv2/np inputs and NMS
+             except Exception:
+                 model = attempt_load(path, device=device, fuse=False)  # arbitrary model
+         else:
+             cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0]  # model.yaml path
+             model = DetectionModel(cfg, channels, classes)  # create model
+             if pretrained:
+                 ckpt = torch.load(attempt_download(path), map_location=device)  # load
+                 csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+                 csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors'])  # intersect
+                 model.load_state_dict(csd, strict=False)  # load
+                 if len(ckpt['model'].names) == classes:
+                     model.names = ckpt['model'].names  # set class names attribute
+         if not verbose:
+             LOGGER.setLevel(logging.INFO)  # reset to default
+         return model.to(device)
+
+     except Exception as e:
+         help_url = 'https://github.com/ultralytics/yolov5/issues/36'
+         s = f'{e}. Cache may be out of date, try `force_reload=True` or see {help_url} for help.'
+         raise Exception(s) from e
+
+
+ def custom(path='path/to/model.pt', autoshape=True, _verbose=True, device=None):
+     # YOLO custom or local model
+     return _create(path, autoshape=autoshape, verbose=_verbose, device=device)
+
+
+ if __name__ == '__main__':
+     import argparse
+     from pathlib import Path
+
+     import numpy as np
+     from PIL import Image
+
+     from utils.general import cv2, print_args
+
+     # Argparser
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--model', type=str, default='yolo', help='model name')
+     opt = parser.parse_args()
+     print_args(vars(opt))
+
+     # Model
+     model = _create(name=opt.model, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)
+     # model = custom(path='path/to/model.pt')  # custom
+
+     # Images
+     imgs = [
+         'data/images/zidane.jpg',  # filename
+         Path('data/images/zidane.jpg'),  # Path
+         'https://ultralytics.com/images/zidane.jpg',  # URI
+         cv2.imread('data/images/bus.jpg')[:, :, ::-1],  # OpenCV
+         Image.open('data/images/bus.jpg'),  # PIL
+         np.zeros((320, 640, 3))]  # numpy
+
+     # Inference
+     results = model(imgs, size=320)  # batched inference
+
+     # Results
+     results.print()
+     results.save()
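With this hubconf.py in place, the model can also be loaded through torch.hub from a local clone; a minimal sketch (the checkpoint path is assumed):

# Sketch, not part of the commit: loading a checkpoint through the custom()
# entry point above, from the repo root.
import torch

model = torch.hub.load('.', 'custom', path='best.pt', source='local')  # assumed checkpoint
results = model('https://ultralytics.com/images/zidane.jpg')  # AutoShape inference
results.print()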
yolov9/models/__init__.py ADDED
@@ -0,0 +1 @@
+ # init
common.py → yolov9/models/common.py RENAMED
File without changes
yolov9/models/detect/gelan-c.yaml ADDED
@@ -0,0 +1,80 @@
+ # YOLOv9
+
+ # parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ #activation: nn.LeakyReLU(0.1)
+ #activation: nn.ReLU()
+
+ # anchors
+ anchors: 3
+
+ # gelan backbone
+ backbone:
+   [
+    # conv down
+    [-1, 1, Conv, [64, 3, 2]],  # 0-P1/2
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+
+    # elan-1 block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 2
+
+    # avg-conv down
+    [-1, 1, ADown, [256]],  # 3-P3/8
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 4
+
+    # avg-conv down
+    [-1, 1, ADown, [512]],  # 5-P4/16
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 6
+
+    # avg-conv down
+    [-1, 1, ADown, [512]],  # 7-P5/32
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 8
+   ]
+
+ # gelan head
+ head:
+   [
+    # elan-spp block
+    [-1, 1, SPPELAN, [512, 256]],  # 9
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 12
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 15 (P3/8-small)
+
+    # avg-conv-down merge
+    [-1, 1, ADown, [256]],
+    [[-1, 12], 1, Concat, [1]],  # cat head P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 18 (P4/16-medium)
+
+    # avg-conv-down merge
+    [-1, 1, ADown, [512]],
+    [[-1, 9], 1, Concat, [1]],  # cat head P5
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 21 (P5/32-large)
+
+    # detect
+    [[15, 18, 21], 1, DDetect, [nc]],  # DDetect(P3, P4, P5)
+   ]
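These model yamls are parsed into a network graph by models/yolo.py; a minimal sketch of building gelan-c from its config, using the DetectionModel class imported elsewhere in this commit (the info() call is assumed to exist as in YOLOv5-lineage repos):

# Sketch, not part of the commit: instantiate the gelan-c architecture from its yaml.
from models.yolo import DetectionModel

model = DetectionModel('models/detect/gelan-c.yaml', ch=3, nc=80)  # 3-channel input, 80 classes
model.info()  # assumed: prints the layer-by-layer summary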
yolov9/models/detect/gelan-e.yaml ADDED
@@ -0,0 +1,121 @@
+ # YOLOv9
+
+ # parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ #activation: nn.LeakyReLU(0.1)
+ #activation: nn.ReLU()
+
+ # anchors
+ anchors: 3
+
+ # gelan backbone
+ backbone:
+   [
+    [-1, 1, Silence, []],
+
+    # conv down
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4
+
+    # elan-1 block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]],  # 3
+
+    # avg-conv down
+    [-1, 1, ADown, [256]],  # 4-P3/8
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]],  # 5
+
+    # avg-conv down
+    [-1, 1, ADown, [512]],  # 6-P4/16
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 7
+
+    # avg-conv down
+    [-1, 1, ADown, [1024]],  # 8-P5/32
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 9
+
+    # routing
+    [1, 1, CBLinear, [[64]]],  # 10
+    [3, 1, CBLinear, [[64, 128]]],  # 11
+    [5, 1, CBLinear, [[64, 128, 256]]],  # 12
+    [7, 1, CBLinear, [[64, 128, 256, 512]]],  # 13
+    [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]],  # 14
+
+    # conv down fuse
+    [0, 1, Conv, [64, 3, 2]],  # 15-P1/2
+    [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]],  # 16
+
+    # conv down fuse
+    [-1, 1, Conv, [128, 3, 2]],  # 17-P2/4
+    [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]],  # 18
+
+    # elan-1 block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]],  # 19
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [256]],  # 20-P3/8
+    [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]],  # 21
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]],  # 22
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [512]],  # 23-P4/16
+    [[13, 14, -1], 1, CBFuse, [[3, 3]]],  # 24
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 25
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [1024]],  # 26-P5/32
+    [[14, -1], 1, CBFuse, [[4]]],  # 27
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 28
+   ]
+
+ # gelan head
+ head:
+   [
+    # elan-spp block
+    [28, 1, SPPELAN, [512, 256]],  # 29
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 25], 1, Concat, [1]],  # cat backbone P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 32
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 22], 1, Concat, [1]],  # cat backbone P3
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]],  # 35 (P3/8-small)
+
+    # avg-conv-down merge
+    [-1, 1, ADown, [256]],
+    [[-1, 32], 1, Concat, [1]],  # cat head P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 38 (P4/16-medium)
+
+    # avg-conv-down merge
+    [-1, 1, ADown, [512]],
+    [[-1, 29], 1, Concat, [1]],  # cat head P5
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]],  # 41 (P5/32-large)
+
+    # detect
+    [[35, 38, 41], 1, DDetect, [nc]],  # Detect(P3, P4, P5)
+   ]
yolov9/models/detect/gelan.yaml ADDED
@@ -0,0 +1,80 @@
+ # YOLOv9
+
+ # parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ #activation: nn.LeakyReLU(0.1)
+ #activation: nn.ReLU()
+
+ # anchors
+ anchors: 3
+
+ # gelan backbone
+ backbone:
+   [
+    # conv down
+    [-1, 1, Conv, [64, 3, 2]],  # 0-P1/2
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
+
+    # elan-1 block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 2
+
+    # avg-conv down
+    [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 4
+
+    # avg-conv down
+    [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 6
+
+    # avg-conv down
+    [-1, 1, Conv, [512, 3, 2]],  # 7-P5/32
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 8
+   ]
+
+ # gelan head
+ head:
+   [
+    # elan-spp block
+    [-1, 1, SPPELAN, [512, 256]],  # 9
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 6], 1, Concat, [1]],  # cat backbone P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 12
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 4], 1, Concat, [1]],  # cat backbone P3
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 15 (P3/8-small)
+
+    # avg-conv-down merge
+    [-1, 1, Conv, [256, 3, 2]],
+    [[-1, 12], 1, Concat, [1]],  # cat head P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 18 (P4/16-medium)
+
+    # avg-conv-down merge
+    [-1, 1, Conv, [512, 3, 2]],
+    [[-1, 9], 1, Concat, [1]],  # cat head P5
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 21 (P5/32-large)
+
+    # detect
+    [[15, 18, 21], 1, DDetect, [nc]],  # Detect(P3, P4, P5)
+   ]
yolov9/models/detect/yolov7-af.yaml ADDED
@@ -0,0 +1,137 @@
+ # YOLOv7
+
+ # Parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.  # model depth multiple
+ width_multiple: 1.  # layer channel multiple
+ anchors: 3
+
+ # YOLOv7 backbone
+ backbone:
+   # [from, number, module, args]
+   [[-1, 1, Conv, [32, 3, 1]],  # 0
+
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+    [-1, 1, Conv, [64, 3, 1]],
+
+    [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+    [-1, 1, Conv, [64, 1, 1]],
+    [-2, 1, Conv, [64, 1, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 11
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [128, 1, 1]],
+    [-3, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 16-P3/8
+    [-1, 1, Conv, [128, 1, 1]],
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [512, 1, 1]],  # 24
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-3, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 29-P4/16
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [1024, 1, 1]],  # 37
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [512, 1, 1]],
+    [-3, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [512, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 42-P5/32
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [1024, 1, 1]],  # 50
+   ]
+
+ # yolov7 head
+ head:
+   [[-1, 1, SPPCSPC, [512]],  # 51
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [37, 1, Conv, [256, 1, 1]],  # route backbone P4
+    [[-1, -2], 1, Concat, [1]],
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 63
+
+    [-1, 1, Conv, [128, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [24, 1, Conv, [128, 1, 1]],  # route backbone P3
+    [[-1, -2], 1, Concat, [1]],
+
+    [-1, 1, Conv, [128, 1, 1]],
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [128, 1, 1]],  # 75
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [128, 1, 1]],
+    [-3, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 2]],
+    [[-1, -3, 63], 1, Concat, [1]],
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 88
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-3, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 2]],
+    [[-1, -3, 51], 1, Concat, [1]],
+
+    [-1, 1, Conv, [512, 1, 1]],
+    [-2, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [512, 1, 1]],  # 101
+
+    [75, 1, Conv, [256, 3, 1]],
+    [88, 1, Conv, [512, 3, 1]],
+    [101, 1, Conv, [1024, 3, 1]],
+
+    [[102, 103, 104], 1, Detect, [nc]],  # Detect(P3, P4, P5)
+   ]
yolov9/models/detect/yolov9-c.yaml ADDED
@@ -0,0 +1,124 @@
+ # YOLOv9
+
+ # parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ #activation: nn.LeakyReLU(0.1)
+ #activation: nn.ReLU()
+
+ # anchors
+ anchors: 3
+
+ # YOLOv9 backbone
+ backbone:
+   [
+    [-1, 1, Silence, []],
+
+    # conv down
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4
+
+    # elan-1 block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3
+
+    # avg-conv down
+    [-1, 1, ADown, [256]],  # 4-P3/8
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5
+
+    # avg-conv down
+    [-1, 1, ADown, [512]],  # 6-P4/16
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7
+
+    # avg-conv down
+    [-1, 1, ADown, [512]],  # 8-P5/32
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
+   ]
+
+ # YOLOv9 head
+ head:
+   [
+    # elan-spp block
+    [-1, 1, SPPELAN, [512, 256]],  # 10
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 7], 1, Concat, [1]],  # cat backbone P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 5], 1, Concat, [1]],  # cat backbone P3
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)
+
+    # avg-conv-down merge
+    [-1, 1, ADown, [256]],
+    [[-1, 13], 1, Concat, [1]],  # cat head P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)
+
+    # avg-conv-down merge
+    [-1, 1, ADown, [512]],
+    [[-1, 10], 1, Concat, [1]],  # cat head P5
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
+
+
+    # multi-level reversible auxiliary branch
+
+    # routing
+    [5, 1, CBLinear, [[256]]],  # 23
+    [7, 1, CBLinear, [[256, 512]]],  # 24
+    [9, 1, CBLinear, [[256, 512, 512]]],  # 25
+
+    # conv down
+    [0, 1, Conv, [64, 3, 2]],  # 26-P1/2
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4
+
+    # elan-1 block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [256]],  # 29-P3/8
+    [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]],  # 30
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [512]],  # 32-P4/16
+    [[24, 25, -1], 1, CBFuse, [[1, 1]]],  # 33
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [512]],  # 35-P5/32
+    [[25, -1], 1, CBFuse, [[2]]],  # 36
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
+
+
+
+    # detection head
+
+    # detect
+    [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
+   ]
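
yolov9-c pairs the main GELAN path (P3/P4/P5 at layers 16/19/22) with a reversible auxiliary branch (layers 23-37) that matters only during training; DualDDetect consumes all six feature maps. A minimal loading sketch follows; it assumes the yolov9 root is on `sys.path` and that `models/yolo.py` exposes a `DetectionModel(cfg, ch, nc)` class (with `Model` as a backwards-compatible alias), as is conventional in this code family.

```python
# A minimal sketch, assuming the repo layout introduced by this commit.
import sys
sys.path.append('yolov9')  # assumed clone location so 'models.*' resolves

import torch
from models.yolo import DetectionModel  # assumed export of models/yolo.py

model = DetectionModel('yolov9/models/detect/yolov9-c.yaml', ch=3, nc=80)
model.eval()
with torch.no_grad():
    out = model(torch.zeros(1, 3, 640, 640))  # dummy 640x640 RGB batch of 1
```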
yolov9/models/detect/yolov9-e.yaml ADDED
@@ -0,0 +1,144 @@
+ # YOLOv9
+
+ # parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ #activation: nn.LeakyReLU(0.1)
+ #activation: nn.ReLU()
+
+ # anchors
+ anchors: 3
+
+ # YOLOv9 backbone
+ backbone:
+   [
+    [-1, 1, Silence, []],
+
+    # conv down
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]],  # 3
+
+    # avg-conv down
+    [-1, 1, ADown, [256]],  # 4-P3/8
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]],  # 5
+
+    # avg-conv down
+    [-1, 1, ADown, [512]],  # 6-P4/16
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 7
+
+    # avg-conv down
+    [-1, 1, ADown, [1024]],  # 8-P5/32
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 9
+
+    # routing
+    [1, 1, CBLinear, [[64]]],  # 10
+    [3, 1, CBLinear, [[64, 128]]],  # 11
+    [5, 1, CBLinear, [[64, 128, 256]]],  # 12
+    [7, 1, CBLinear, [[64, 128, 256, 512]]],  # 13
+    [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]],  # 14
+
+    # conv down
+    [0, 1, Conv, [64, 3, 2]],  # 15-P1/2
+    [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]],  # 16
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 17-P2/4
+    [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]],  # 18
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]],  # 19
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [256]],  # 20-P3/8
+    [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]],  # 21
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]],  # 22
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [512]],  # 23-P4/16
+    [[13, 14, -1], 1, CBFuse, [[3, 3]]],  # 24
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 25
+
+    # avg-conv down fuse
+    [-1, 1, ADown, [1024]],  # 26-P5/32
+    [[14, -1], 1, CBFuse, [[4]]],  # 27
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]],  # 28
+   ]
+
+ # YOLOv9 head
+ head:
+   [
+    # multi-level auxiliary branch
+
+    # elan-spp block
+    [9, 1, SPPELAN, [512, 256]],  # 29
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 7], 1, Concat, [1]],  # cat backbone P4
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 32
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 5], 1, Concat, [1]],  # cat backbone P3
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]],  # 35
+
+
+
+    # main branch
+
+    # elan-spp block
+    [28, 1, SPPELAN, [512, 256]],  # 36
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 25], 1, Concat, [1]],  # cat backbone P4
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 39
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 22], 1, Concat, [1]],  # cat backbone P3
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]],  # 42 (P3/8-small)
+
+    # avg-conv-down merge
+    [-1, 1, ADown, [256]],
+    [[-1, 39], 1, Concat, [1]],  # cat head P4
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]],  # 45 (P4/16-medium)
+
+    # avg-conv-down merge
+    [-1, 1, ADown, [512]],
+    [[-1, 36], 1, Concat, [1]],  # cat head P5
+
+    # csp-elan block
+    [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]],  # 48 (P5/32-large)
+
+    # detect
+    [[35, 32, 29, 42, 45, 48], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
+   ]
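
The CBLinear/CBFuse pairs implement the routing: CBLinear projects a backbone feature map into several channel groups (one per destination level), and CBFuse resizes the matching group from each source to the target resolution and sums them with the target map. The toy modules below sketch that idea under simplified assumptions; they are stand-ins for illustration, not the CBLinear/CBFuse classes defined in models/common.py.

```python
# Toy sketch of the CBLinear/CBFuse routing idea (illustrative stand-ins only).
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCBLinear(nn.Module):
    """Project one feature map into several channel groups (toy version)."""

    def __init__(self, c_in, c_outs):  # c_outs e.g. [64, 128, 256]
        super().__init__()
        self.c_outs = c_outs
        self.conv = nn.Conv2d(c_in, sum(c_outs), 1)

    def forward(self, x):  # -> tuple of per-destination maps
        return self.conv(x).split(self.c_outs, dim=1)


class ToyCBFuse(nn.Module):
    """Resize one group from each source and sum into the target map (toy version)."""

    def __init__(self, idx):  # which group to take from each source
        super().__init__()
        self.idx = idx

    def forward(self, xs):  # xs = [*cblinear_outputs, target_map]
        size = xs[-1].shape[2:]
        res = [F.interpolate(x[i], size=size, mode='nearest') for i, x in zip(self.idx, xs[:-1])]
        return sum(res) + xs[-1]


lin = ToyCBLinear(512, [64, 128, 256])
splits = lin(torch.zeros(1, 512, 20, 20))  # low-resolution source, three groups
target = torch.zeros(1, 64, 160, 160)      # high-resolution target map
fused = ToyCBFuse([0])([splits, target])   # take group 0, upsample, add
print(fused.shape)                         # torch.Size([1, 64, 160, 160])
```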
yolov9/models/detect/yolov9.yaml ADDED
@@ -0,0 +1,117 @@
+ # YOLOv9
+
+ # parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ #activation: nn.LeakyReLU(0.1)
+ #activation: nn.ReLU()
+
+ # anchors
+ anchors: 3
+
+ # YOLOv9 backbone
+ backbone:
+   [
+    [-1, 1, Silence, []],
+
+    # conv down
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 2-P2/4
+
+    # elan-1 block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 3
+
+    # conv down
+    [-1, 1, Conv, [256, 3, 2]],  # 4-P3/8
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 5
+
+    # conv down
+    [-1, 1, Conv, [512, 3, 2]],  # 6-P4/16
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 7
+
+    # conv down
+    [-1, 1, Conv, [512, 3, 2]],  # 8-P5/32
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 9
+   ]
+
+ # YOLOv9 head
+ head:
+   [
+    # elan-spp block
+    [-1, 1, SPPELAN, [512, 256]],  # 10
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 7], 1, Concat, [1]],  # cat backbone P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 13
+
+    # up-concat merge
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 5], 1, Concat, [1]],  # cat backbone P3
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]],  # 16 (P3/8-small)
+
+    # conv-down merge
+    [-1, 1, Conv, [256, 3, 2]],
+    [[-1, 13], 1, Concat, [1]],  # cat head P4
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 19 (P4/16-medium)
+
+    # conv-down merge
+    [-1, 1, Conv, [512, 3, 2]],
+    [[-1, 10], 1, Concat, [1]],  # cat head P5
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 22 (P5/32-large)
+
+    # routing
+    [5, 1, CBLinear, [[256]]],  # 23
+    [7, 1, CBLinear, [[256, 512]]],  # 24
+    [9, 1, CBLinear, [[256, 512, 512]]],  # 25
+
+    # conv down
+    [0, 1, Conv, [64, 3, 2]],  # 26-P1/2
+
+    # conv down
+    [-1, 1, Conv, [128, 3, 2]],  # 27-P2/4
+
+    # elan-1 block
+    [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]],  # 28
+
+    # conv down fuse
+    [-1, 1, Conv, [256, 3, 2]],  # 29-P3/8
+    [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]],  # 30
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]],  # 31
+
+    # conv down fuse
+    [-1, 1, Conv, [512, 3, 2]],  # 32-P4/16
+    [[24, 25, -1], 1, CBFuse, [[1, 1]]],  # 33
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 34
+
+    # conv down fuse
+    [-1, 1, Conv, [512, 3, 2]],  # 35-P5/32
+    [[25, -1], 1, CBFuse, [[2]]],  # 36
+
+    # elan-2 block
+    [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]],  # 37
+
+    # detect
+    [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]],  # DualDDetect(A3, A4, A5, P3, P4, P5)
+   ]
experimental.py → yolov9/models/experimental.py RENAMED
File without changes
yolov9/models/hub/anchors.yaml ADDED
@@ -0,0 +1,59 @@
+ # YOLOv3 & YOLOv5
+ # Default anchors for COCO data
+
+
+ # P5 -------------------------------------------------------------------------------------------------------------------
+ # P5-640:
+ anchors_p5_640:
+   - [10,13, 16,30, 33,23]  # P3/8
+   - [30,61, 62,45, 59,119]  # P4/16
+   - [116,90, 156,198, 373,326]  # P5/32
+
+
+ # P6 -------------------------------------------------------------------------------------------------------------------
+ # P6-640: thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387
+ anchors_p6_640:
+   - [9,11, 21,19, 17,41]  # P3/8
+   - [43,32, 39,70, 86,64]  # P4/16
+   - [65,131, 134,130, 120,265]  # P5/32
+   - [282,180, 247,354, 512,387]  # P6/64
+
+ # P6-1280: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792
+ anchors_p6_1280:
+   - [19,27, 44,40, 38,94]  # P3/8
+   - [96,68, 86,152, 180,137]  # P4/16
+   - [140,301, 303,264, 238,542]  # P5/32
+   - [436,615, 739,380, 925,792]  # P6/64
+
+ # P6-1920: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187
+ anchors_p6_1920:
+   - [28,41, 67,59, 57,141]  # P3/8
+   - [144,103, 129,227, 270,205]  # P4/16
+   - [209,452, 455,396, 358,812]  # P5/32
+   - [653,922, 1109,570, 1387,1187]  # P6/64
+
+
+ # P7 -------------------------------------------------------------------------------------------------------------------
+ # P7-640: thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372
+ anchors_p7_640:
+   - [11,11, 13,30, 29,20]  # P3/8
+   - [30,46, 61,38, 39,92]  # P4/16
+   - [78,80, 146,66, 79,163]  # P5/32
+   - [149,150, 321,143, 157,303]  # P6/64
+   - [257,402, 359,290, 524,372]  # P7/128
+
+ # P7-1280: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818
+ anchors_p7_1280:
+   - [19,22, 54,36, 32,77]  # P3/8
+   - [70,83, 138,71, 75,173]  # P4/16
+   - [165,159, 148,334, 375,151]  # P5/32
+   - [334,317, 251,626, 499,474]  # P6/64
+   - [750,326, 534,814, 1079,818]  # P7/128
+
+ # P7-1920: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227
+ anchors_p7_1920:
+   - [29,34, 81,55, 47,115]  # P3/8
+   - [105,124, 207,107, 113,259]  # P4/16
+   - [247,238, 222,500, 563,227]  # P5/32
+   - [501,476, 376,939, 749,711]  # P6/64
+   - [1126,489, 801,1222, 1618,1227]  # P7/128
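
These are named anchor sets: flattened (w, h) pairs in pixels at the stated training resolution, three anchors per output level, with the BPR comments recording the AutoAnchor fit that produced them. A small sketch of reading one set by name (illustrative only; the file path assumes this repo layout):

```python
# Illustrative only: pick one of the named anchor sets from anchors.yaml.
import yaml

with open('yolov9/models/hub/anchors.yaml', errors='ignore') as f:
    anchor_sets = yaml.safe_load(f)

for level, row in zip(('P3/8', 'P4/16', 'P5/32'), anchor_sets['anchors_p5_640']):
    pairs = list(zip(row[::2], row[1::2]))  # flatten [w1,h1, w2,h2, ...] -> (w, h) pairs
    print(level, pairs)
```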
yolov9/models/hub/yolov3-spp.yaml ADDED
@@ -0,0 +1,51 @@
+ # YOLOv3
+
+ # Parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ anchors:
+   - [10,13, 16,30, 33,23]  # P3/8
+   - [30,61, 62,45, 59,119]  # P4/16
+   - [116,90, 156,198, 373,326]  # P5/32
+
+ # darknet53 backbone
+ backbone:
+   # [from, number, module, args]
+   [[-1, 1, Conv, [32, 3, 1]],  # 0
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+    [-1, 1, Bottleneck, [64]],
+    [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+    [-1, 2, Bottleneck, [128]],
+    [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
+    [-1, 8, Bottleneck, [256]],
+    [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
+    [-1, 8, Bottleneck, [512]],
+    [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
+    [-1, 4, Bottleneck, [1024]],  # 10
+   ]
+
+ # YOLOv3-SPP head
+ head:
+   [[-1, 1, Bottleneck, [1024, False]],
+    [-1, 1, SPP, [512, [5, 9, 13]]],
+    [-1, 1, Conv, [1024, 3, 1]],
+    [-1, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [1024, 3, 1]],  # 15 (P5/32-large)
+
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+    [-1, 1, Bottleneck, [512, False]],
+    [-1, 1, Bottleneck, [512, False]],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [512, 3, 1]],  # 22 (P4/16-medium)
+
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 6], 1, Concat, [1]],  # cat backbone P3
+    [-1, 1, Bottleneck, [256, False]],
+    [-1, 2, Bottleneck, [256, False]],  # 27 (P3/8-small)
+
+    [[27, 22, 15], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+   ]
yolov9/models/hub/yolov3-tiny.yaml ADDED
@@ -0,0 +1,41 @@
+ # YOLOv3
+
+ # Parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ anchors:
+   - [10,14, 23,27, 37,58]  # P4/16
+   - [81,82, 135,169, 344,319]  # P5/32
+
+ # YOLOv3-tiny backbone
+ backbone:
+   # [from, number, module, args]
+   [[-1, 1, Conv, [16, 3, 1]],  # 0
+    [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 1-P1/2
+    [-1, 1, Conv, [32, 3, 1]],
+    [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 3-P2/4
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 5-P3/8
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 7-P4/16
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, nn.MaxPool2d, [2, 2, 0]],  # 9-P5/32
+    [-1, 1, Conv, [512, 3, 1]],
+    [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]],  # 11
+    [-1, 1, nn.MaxPool2d, [2, 1, 0]],  # 12
+   ]
+
+ # YOLOv3-tiny head
+ head:
+   [[-1, 1, Conv, [1024, 3, 1]],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [512, 3, 1]],  # 15 (P5/32-large)
+
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+    [-1, 1, Conv, [256, 3, 1]],  # 19 (P4/16-medium)
+
+    [[19, 15], 1, Detect, [nc, anchors]],  # Detect(P4, P5)
+   ]
yolov9/models/hub/yolov3.yaml ADDED
@@ -0,0 +1,51 @@
+ # YOLOv3
+
+ # Parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ anchors:
+   - [10,13, 16,30, 33,23]  # P3/8
+   - [30,61, 62,45, 59,119]  # P4/16
+   - [116,90, 156,198, 373,326]  # P5/32
+
+ # darknet53 backbone
+ backbone:
+   # [from, number, module, args]
+   [[-1, 1, Conv, [32, 3, 1]],  # 0
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+    [-1, 1, Bottleneck, [64]],
+    [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+    [-1, 2, Bottleneck, [128]],
+    [-1, 1, Conv, [256, 3, 2]],  # 5-P3/8
+    [-1, 8, Bottleneck, [256]],
+    [-1, 1, Conv, [512, 3, 2]],  # 7-P4/16
+    [-1, 8, Bottleneck, [512]],
+    [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
+    [-1, 4, Bottleneck, [1024]],  # 10
+   ]
+
+ # YOLOv3 head
+ head:
+   [[-1, 1, Bottleneck, [1024, False]],
+    [-1, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [1024, 3, 1]],
+    [-1, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [1024, 3, 1]],  # 15 (P5/32-large)
+
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 8], 1, Concat, [1]],  # cat backbone P4
+    [-1, 1, Bottleneck, [512, False]],
+    [-1, 1, Bottleneck, [512, False]],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [512, 3, 1]],  # 22 (P4/16-medium)
+
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [[-1, 6], 1, Concat, [1]],  # cat backbone P3
+    [-1, 1, Bottleneck, [256, False]],
+    [-1, 2, Bottleneck, [256, False]],  # 27 (P3/8-small)
+
+    [[27, 22, 15], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
+   ]
yolov9/models/panoptic/yolov7-af-pan.yaml ADDED
@@ -0,0 +1,137 @@
+ # YOLOv7
+
+ # Parameters
+ nc: 80  # number of classes
+ sem_nc: 93  # number of stuff classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ anchors: 3
+
+ # YOLOv7 backbone
+ backbone:
+   [[-1, 1, Conv, [32, 3, 1]],  # 0
+
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+    [-1, 1, Conv, [64, 3, 1]],
+
+    [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+    [-1, 1, Conv, [64, 1, 1]],
+    [-2, 1, Conv, [64, 1, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 11
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [128, 1, 1]],
+    [-3, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 16-P3/8
+    [-1, 1, Conv, [128, 1, 1]],
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [512, 1, 1]],  # 24
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-3, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 29-P4/16
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [1024, 1, 1]],  # 37
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [512, 1, 1]],
+    [-3, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [512, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 42-P5/32
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [1024, 1, 1]],  # 50
+   ]
+
+ # yolov7 head
+ head:
+   [[-1, 1, SPPCSPC, [512]],  # 51
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [37, 1, Conv, [256, 1, 1]],  # route backbone P4
+    [[-1, -2], 1, Concat, [1]],
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 63
+
+    [-1, 1, Conv, [128, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [24, 1, Conv, [128, 1, 1]],  # route backbone P3
+    [[-1, -2], 1, Concat, [1]],
+
+    [-1, 1, Conv, [128, 1, 1]],
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [128, 1, 1]],  # 75
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [128, 1, 1]],
+    [-3, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 2]],
+    [[-1, -3, 63], 1, Concat, [1]],
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 88
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-3, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 2]],
+    [[-1, -3, 51], 1, Concat, [1]],
+
+    [-1, 1, Conv, [512, 1, 1]],
+    [-2, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [512, 1, 1]],  # 101
+
+    [75, 1, Conv, [256, 3, 1]],
+    [88, 1, Conv, [512, 3, 1]],
+    [101, 1, Conv, [1024, 3, 1]],
+
+    [[102, 103, 104], 1, Panoptic, [nc, 93, 32, 256]],  # Panoptic(P3, P4, P5)
+   ]
yolov9/models/segment/yolov7-af-seg.yaml ADDED
@@ -0,0 +1,136 @@
+ # YOLOv7
+
+ # Parameters
+ nc: 80  # number of classes
+ depth_multiple: 1.0  # model depth multiple
+ width_multiple: 1.0  # layer channel multiple
+ anchors: 3
+
+ # YOLOv7 backbone
+ backbone:
+   [[-1, 1, Conv, [32, 3, 1]],  # 0
+
+    [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
+    [-1, 1, Conv, [64, 3, 1]],
+
+    [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
+    [-1, 1, Conv, [64, 1, 1]],
+    [-2, 1, Conv, [64, 1, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 11
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [128, 1, 1]],
+    [-3, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 16-P3/8
+    [-1, 1, Conv, [128, 1, 1]],
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [512, 1, 1]],  # 24
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-3, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 29-P4/16
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [1024, 1, 1]],  # 37
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [512, 1, 1]],
+    [-3, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [512, 3, 2]],
+    [[-1, -3], 1, Concat, [1]],  # 42-P5/32
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -3, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [1024, 1, 1]],  # 50
+   ]
+
+ # yolov7 head
+ head:
+   [[-1, 1, SPPCSPC, [512]],  # 51
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [37, 1, Conv, [256, 1, 1]],  # route backbone P4
+    [[-1, -2], 1, Concat, [1]],
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 63
+
+    [-1, 1, Conv, [128, 1, 1]],
+    [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+    [24, 1, Conv, [128, 1, 1]],  # route backbone P3
+    [[-1, -2], 1, Concat, [1]],
+
+    [-1, 1, Conv, [128, 1, 1]],
+    [-2, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [-1, 1, Conv, [64, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [128, 1, 1]],  # 75
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [128, 1, 1]],
+    [-3, 1, Conv, [128, 1, 1]],
+    [-1, 1, Conv, [128, 3, 2]],
+    [[-1, -3, 63], 1, Concat, [1]],
+
+    [-1, 1, Conv, [256, 1, 1]],
+    [-2, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [-1, 1, Conv, [128, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [256, 1, 1]],  # 88
+
+    [-1, 1, MP, []],
+    [-1, 1, Conv, [256, 1, 1]],
+    [-3, 1, Conv, [256, 1, 1]],
+    [-1, 1, Conv, [256, 3, 2]],
+    [[-1, -3, 51], 1, Concat, [1]],
+
+    [-1, 1, Conv, [512, 1, 1]],
+    [-2, 1, Conv, [512, 1, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [-1, 1, Conv, [256, 3, 1]],
+    [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
+    [-1, 1, Conv, [512, 1, 1]],  # 101
+
+    [75, 1, Conv, [256, 3, 1]],
+    [88, 1, Conv, [512, 3, 1]],
+    [101, 1, Conv, [1024, 3, 1]],
+
+    [[102, 103, 104], 1, Segment, [nc, 32, 256]],  # Segment(P3, P4, P5)
+   ]
tf.py → yolov9/models/tf.py RENAMED
File without changes
yolo.py → yolov9/models/yolo.py RENAMED
@@ -12,13 +12,13 @@ if str(ROOT) not in sys.path:
 if platform.system() != 'Windows':
     ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 
-from common import *
-from experimental import *
-from general import LOGGER, check_version, check_yaml, make_divisible, print_args
-from plots import feature_visualization
-from torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
+from models.common import *
+from models.experimental import *
+from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
+from utils.plots import feature_visualization
+from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
                          time_sync)
-from anchor_generator import make_anchors, dist2bbox
+from utils.tal.anchor_generator import make_anchors, dist2bbox
 
 try:
     import thop  # for FLOPs computation
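
The rename moves the formerly flat modules into a `models`/`utils` package layout, so yolo.py now resolves its imports relative to the yolov9 root; callers outside the package need that root on `sys.path`, which the scripts in this commit arrange via `FILE.parents[1]`. A minimal sketch of the same bootstrapping from outside the package (the clone location and the `Model` alias are assumptions):

```python
# A minimal sketch, mirroring the ROOT/sys.path setup used by the scripts above.
import sys
from pathlib import Path

ROOT = Path('yolov9').resolve()  # assumed: the repo subfolder added by this commit
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))   # so 'models.*' and 'utils.*' imports resolve

from models.yolo import Model  # noqa: E402  (assumed alias of DetectionModel)
```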
yolov9/panoptic/predict.py ADDED
@@ -0,0 +1,246 @@
+ import argparse
+ import os
+ import platform
+ import sys
+ from pathlib import Path
+
+ import torch
+
+ FILE = Path(__file__).resolve()
+ ROOT = FILE.parents[1]  # YOLO root directory
+ if str(ROOT) not in sys.path:
+     sys.path.append(str(ROOT))  # add ROOT to PATH
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+ from models.common import DetectMultiBackend
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
+                            increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
+                            strip_optimizer, xyxy2xywh)
+ from utils.plots import Annotator, colors, save_one_box
+ from utils.segment.general import masks2segments, process_mask
+ from utils.torch_utils import select_device, smart_inference_mode
+
+
+ @smart_inference_mode()
+ def run(
+         weights=ROOT / 'yolo-pan.pt',  # model.pt path(s)
+         source=ROOT / 'data/images',  # file/dir/URL/glob/screen/0(webcam)
+         data=ROOT / 'data/coco128.yaml',  # dataset.yaml path
+         imgsz=(640, 640),  # inference size (height, width)
+         conf_thres=0.25,  # confidence threshold
+         iou_thres=0.45,  # NMS IOU threshold
+         max_det=1000,  # maximum detections per image
+         device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+         view_img=False,  # show results
+         save_txt=False,  # save results to *.txt
+         save_conf=False,  # save confidences in --save-txt labels
+         save_crop=False,  # save cropped prediction boxes
+         nosave=False,  # do not save images/videos
+         classes=None,  # filter by class: --class 0, or --class 0 2 3
+         agnostic_nms=False,  # class-agnostic NMS
+         augment=False,  # augmented inference
+         visualize=False,  # visualize features
+         update=False,  # update all models
+         project=ROOT / 'runs/predict-seg',  # save results to project/name
+         name='exp',  # save results to project/name
+         exist_ok=False,  # existing project/name ok, do not increment
+         line_thickness=3,  # bounding box thickness (pixels)
+         hide_labels=False,  # hide labels
+         hide_conf=False,  # hide confidences
+         half=False,  # use FP16 half-precision inference
+         dnn=False,  # use OpenCV DNN for ONNX inference
+         vid_stride=1,  # video frame-rate stride
+         retina_masks=False,
+ ):
+     source = str(source)
+     save_img = not nosave and not source.endswith('.txt')  # save inference images
+     is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
+     is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
+     webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
+     screenshot = source.lower().startswith('screen')
+     if is_url and is_file:
+         source = check_file(source)  # download
+
+     # Directories
+     save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
+     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+
+     # Load model
+     device = select_device(device)
+     model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
+     stride, names, pt = model.stride, model.names, model.pt
+     imgsz = check_img_size(imgsz, s=stride)  # check image size
+
+     # Dataloader
+     bs = 1  # batch_size
+     if webcam:
+         view_img = check_imshow(warn=True)
+         dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
+         bs = len(dataset)
+     elif screenshot:
+         dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
+     else:
+         dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
+     vid_path, vid_writer = [None] * bs, [None] * bs
+
+     # Run inference
+     model.warmup(imgsz=(1 if pt else bs, 3, *imgsz))  # warmup
+     seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
+     for path, im, im0s, vid_cap, s in dataset:
+         with dt[0]:
+             im = torch.from_numpy(im).to(model.device)
+             im = im.half() if model.fp16 else im.float()  # uint8 to fp16/32
+             im /= 255  # 0 - 255 to 0.0 - 1.0
+             if len(im.shape) == 3:
+                 im = im[None]  # expand for batch dim
+
+         # Inference
+         with dt[1]:
+             visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
+             pred, proto = model(im, augment=augment, visualize=visualize)[:2]
+
+         # NMS
+         with dt[2]:
+             pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det, nm=32)
+
+         # Second-stage classifier (optional)
+         # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
+
+         # Process predictions
+         for i, det in enumerate(pred):  # per image
+             seen += 1
+             if webcam:  # batch_size >= 1
+                 p, im0, frame = path[i], im0s[i].copy(), dataset.count
+                 s += f'{i}: '
+             else:
+                 p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
+
+             p = Path(p)  # to Path
+             save_path = str(save_dir / p.name)  # im.jpg
+             txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
+             s += '%gx%g ' % im.shape[2:]  # print string
+             imc = im0.copy() if save_crop else im0  # for save_crop
+             annotator = Annotator(im0, line_width=line_thickness, example=str(names))
+             if len(det):
+                 masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)  # HWC
+                 det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()  # rescale boxes to im0 size
+
+                 # Segments
+                 if save_txt:
+                     segments = reversed(masks2segments(masks))
+                     segments = [scale_segments(im.shape[2:], x, im0.shape, normalize=True) for x in segments]
+
+                 # Print results
+                 for c in det[:, 5].unique():
+                     n = (det[:, 5] == c).sum()  # detections per class
+                     s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
+
+                 # Mask plotting
+                 annotator.masks(masks,
+                                 colors=[colors(x, True) for x in det[:, 5]],
+                                 im_gpu=None if retina_masks else im[i])
+
+                 # Write results
+                 for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
+                     if save_txt:  # Write to file
+                         segj = segments[j].reshape(-1)  # (n,2) to (n*2)
+                         line = (cls, *segj, conf) if save_conf else (cls, *segj)  # label format
+                         with open(f'{txt_path}.txt', 'a') as f:
+                             f.write(('%g ' * len(line)).rstrip() % line + '\n')
+
+                     if save_img or save_crop or view_img:  # Add bbox to image
+                         c = int(cls)  # integer class
+                         label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
+                         annotator.box_label(xyxy, label, color=colors(c, True))
+                         # annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
+                     if save_crop:
+                         save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
+
+             # Stream results
+             im0 = annotator.result()
+             if view_img:
+                 if platform.system() == 'Linux' and p not in windows:
+                     windows.append(p)
+                     cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
+                     cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
+                 cv2.imshow(str(p), im0)
+                 if cv2.waitKey(1) == ord('q'):  # 1 millisecond
+                     exit()
+
+             # Save results (image with detections)
+             if save_img:
+                 if dataset.mode == 'image':
+                     cv2.imwrite(save_path, im0)
+                 else:  # 'video' or 'stream'
+                     if vid_path[i] != save_path:  # new video
+                         vid_path[i] = save_path
+                         if isinstance(vid_writer[i], cv2.VideoWriter):
+                             vid_writer[i].release()  # release previous video writer
+                         if vid_cap:  # video
+                             fps = vid_cap.get(cv2.CAP_PROP_FPS)
+                             w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                             h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                         else:  # stream
+                             fps, w, h = 30, im0.shape[1], im0.shape[0]
+                         save_path = str(Path(save_path).with_suffix('.mp4'))  # force *.mp4 suffix on results videos
+                         vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+                     vid_writer[i].write(im0)
+
+         # Print time (inference-only)
+         LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
+
+     # Print results
+     t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
+     LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
+     if save_txt or save_img:
+         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
+     if update:
+         strip_optimizer(weights[0])  # update model (to fix SourceChangeWarning)
+
+
+ def parse_opt():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo-pan.pt', help='model path(s)')
+     parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
+     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
+     parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
+     parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
+     parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
+     parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
+     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--view-img', action='store_true', help='show results')
+     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
+     parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+     parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
+     parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
+     parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
+     parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
+     parser.add_argument('--augment', action='store_true', help='augmented inference')
+     parser.add_argument('--visualize', action='store_true', help='visualize features')
+     parser.add_argument('--update', action='store_true', help='update all models')
+     parser.add_argument('--project', default=ROOT / 'runs/predict-seg', help='save results to project/name')
+     parser.add_argument('--name', default='exp', help='save results to project/name')
+     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+     parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
+     parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
+     parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
+     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+     parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+     parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
+     parser.add_argument('--retina-masks', action='store_true', help='whether to plot masks in native resolution')
+     opt = parser.parse_args()
+     opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1  # expand
+     print_args(vars(opt))
+     return opt
+
+
+ def main(opt):
+     check_requirements(exclude=('tensorboard', 'thop'))
+     run(**vars(opt))
+
+
+ if __name__ == "__main__":
+     opt = parse_opt()
+     main(opt)
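
Since parse_opt() and main() are thin wrappers, run() can also be driven programmatically. A hedged example follows: the checkpoint name is just the script's default (a placeholder), and importing `panoptic.predict` assumes the yolov9 root is on `sys.path`.

```python
# Example invocation under stated assumptions (paths are placeholders).
import sys
sys.path.append('yolov9')  # so 'panoptic', 'models' and 'utils' resolve

from panoptic.predict import run

run(weights='yolo-pan.pt',   # placeholder checkpoint name (the script's default)
    source='data/images',    # file/dir/URL/glob/screen/0(webcam)
    imgsz=(640, 640),
    conf_thres=0.25,
    iou_thres=0.45)
```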
yolov9/panoptic/train.py ADDED
@@ -0,0 +1,662 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import random
5
+ import sys
6
+ import time
7
+ from copy import deepcopy
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.distributed as dist
14
+ import torch.nn as nn
15
+ import yaml
16
+ from torch.optim import lr_scheduler
17
+ from tqdm import tqdm
18
+
19
+ FILE = Path(__file__).resolve()
20
+ ROOT = FILE.parents[1] # YOLO root directory
21
+ if str(ROOT) not in sys.path:
22
+ sys.path.append(str(ROOT)) # add ROOT to PATH
23
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
24
+
25
+ import panoptic.val as validate # for end-of-epoch mAP
26
+ from models.experimental import attempt_load
27
+ from models.yolo import SegmentationModel
28
+ from utils.autoanchor import check_anchors
29
+ from utils.autobatch import check_train_batch_size
30
+ from utils.callbacks import Callbacks
31
+ from utils.downloads import attempt_download, is_url
32
+ from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info,
33
+ check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr,
34
+ get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
35
+ labels_to_image_weights, one_cycle, one_flat_cycle, print_args, print_mutation, strip_optimizer, yaml_save)
36
+ from utils.loggers import GenericLogger
37
+ from utils.plots import plot_evolve, plot_labels
38
+ from utils.panoptic.dataloaders import create_dataloader
39
+ from utils.panoptic.loss_tal import ComputeLoss
40
+ from utils.panoptic.metrics import KEYS, fitness
41
+ from utils.panoptic.plots import plot_images_and_masks, plot_results_with_masks
42
+ from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
43
+ smart_resume, torch_distributed_zero_first)
44
+
45
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
46
+ RANK = int(os.getenv('RANK', -1))
47
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
48
+ GIT_INFO = None#check_git_info()
49
+
50
+
51
+ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
52
+ save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, mask_ratio = \
53
+ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
54
+ opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze, opt.mask_ratio
55
+ # callbacks.run('on_pretrain_routine_start')
56
+
57
+ # Directories
58
+ w = save_dir / 'weights' # weights dir
59
+ (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
60
+ last, best = w / 'last.pt', w / 'best.pt'
61
+
62
+ # Hyperparameters
63
+ if isinstance(hyp, str):
64
+ with open(hyp, errors='ignore') as f:
65
+ hyp = yaml.safe_load(f) # load hyps dict
66
+ LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
67
+ opt.hyp = hyp.copy() # for saving hyps to checkpoints
68
+
69
+ # Save run settings
70
+ if not evolve:
71
+ yaml_save(save_dir / 'hyp.yaml', hyp)
72
+ yaml_save(save_dir / 'opt.yaml', vars(opt))
73
+
74
+ # Loggers
75
+ data_dict = None
76
+ if RANK in {-1, 0}:
77
+ logger = GenericLogger(opt=opt, console_logger=LOGGER)
78
+
79
+ # Config
80
+ plots = not evolve and not opt.noplots # create plots
81
+ overlap = not opt.no_overlap
82
+ cuda = device.type != 'cpu'
83
+ init_seeds(opt.seed + 1 + RANK, deterministic=True)
84
+ with torch_distributed_zero_first(LOCAL_RANK):
85
+ data_dict = data_dict or check_dataset(data) # check if None
86
+ train_path, val_path = data_dict['train'], data_dict['val']
87
+ nc = 1 if single_cls else int(data_dict['nc']) # number of classes
88
+ names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
89
+ #is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
90
+ is_coco = isinstance(val_path, str) and val_path.endswith('val2017.txt') # COCO dataset
91
+
92
+ # Model
93
+ check_suffix(weights, '.pt') # check weights
94
+ pretrained = weights.endswith('.pt')
95
+ if pretrained:
96
+ with torch_distributed_zero_first(LOCAL_RANK):
97
+ weights = attempt_download(weights) # download if not found locally
98
+ ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
99
+ model = SegmentationModel(cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)
100
+ exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
101
+ csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
102
+ csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
103
+ model.load_state_dict(csd, strict=False) # load
104
+ LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
105
+ else:
106
+ model = SegmentationModel(cfg, ch=3, nc=nc).to(device) # create
107
+ amp = check_amp(model) # check AMP
108
+
109
+ # Freeze
110
+ freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
111
+ for k, v in model.named_parameters():
112
+ #v.requires_grad = True # train all layers
113
+ # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
114
+ if any(x in k for x in freeze):
115
+ LOGGER.info(f'freezing {k}')
116
+ v.requires_grad = False
117
+
118
+ # Image size
119
+ gs = max(int(model.stride.max()), 32) # grid size (max stride)
120
+ imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
121
+
122
+ # Batch size
123
+ if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
124
+ batch_size = check_train_batch_size(model, imgsz, amp)
125
+ logger.update_params({"batch_size": batch_size})
126
+ # loggers.on_params_update({"batch_size": batch_size})
127
+
128
+ # Optimizer
129
+ nbs = 64 # nominal batch size
130
+ accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
131
+ hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
132
+ optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
133
+
134
+ # Scheduler
135
+ if opt.cos_lr:
136
+ lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
137
+ elif opt.flat_cos_lr:
138
+ lf = one_flat_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
139
+ elif opt.fixed_lr:
140
+ lf = lambda x: 1.0
141
+ elif opt.poly_lr:
142
+ power = 0.9
143
+ lf = lambda x: ((1 - (x / epochs)) ** power) * (1.0 - hyp['lrf']) + hyp['lrf']
144
+ else:
145
+ lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
146
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
147
+
148
+ # EMA
149
+ ema = ModelEMA(model) if RANK in {-1, 0} else None
150
+
151
+ # Resume
152
+ best_fitness, start_epoch = 0.0, 0
153
+ if pretrained:
154
+ if resume:
155
+ best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
156
+ del ckpt, csd
157
+
158
+ # DP mode
159
+ if cuda and RANK == -1 and torch.cuda.device_count() > 1:
160
+ LOGGER.warning('WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.')
161
+         model = torch.nn.DataParallel(model)
+
+     # SyncBatchNorm
+     if opt.sync_bn and cuda and RANK != -1:
+         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
+         LOGGER.info('Using SyncBatchNorm()')
+
+     # Trainloader
+     train_loader, dataset = create_dataloader(
+         train_path,
+         imgsz,
+         batch_size // WORLD_SIZE,
+         gs,
+         single_cls,
+         hyp=hyp,
+         augment=True,
+         cache=None if opt.cache == 'val' else opt.cache,
+         rect=opt.rect,
+         rank=LOCAL_RANK,
+         workers=workers,
+         image_weights=opt.image_weights,
+         close_mosaic=opt.close_mosaic != 0,
+         quad=opt.quad,
+         prefix=colorstr('train: '),
+         shuffle=True,
+         mask_downsample_ratio=mask_ratio,
+         overlap_mask=overlap,
+     )
+     labels = np.concatenate(dataset.labels, 0)
+     mlc = int(labels[:, 0].max())  # max label class
+     assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
+
+     # Process 0
+     if RANK in {-1, 0}:
+         val_loader = create_dataloader(val_path,
+                                        imgsz,
+                                        batch_size // WORLD_SIZE * 2,
+                                        gs,
+                                        single_cls,
+                                        hyp=hyp,
+                                        cache=None if noval else opt.cache,
+                                        rect=True,
+                                        rank=-1,
+                                        workers=workers * 2,
+                                        pad=0.5,
+                                        mask_downsample_ratio=mask_ratio,
+                                        overlap_mask=overlap,
+                                        prefix=colorstr('val: '))[0]
+
+         if not resume:
+             # if not opt.noautoanchor:
+             #     check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)  # run AutoAnchor
+             model.half().float()  # pre-reduce anchor precision
+
+             if plots:
+                 plot_labels(labels, names, save_dir)
+             # callbacks.run('on_pretrain_routine_end', labels, names)
+
+     # DDP mode
+     if cuda and RANK != -1:
+         model = smart_DDP(model)
+
+     # Model attributes
+     nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
+     hyp['box'] *= 3 / nl  # scale to layers
+     hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
+     hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
+     hyp['label_smoothing'] = opt.label_smoothing
+     model.nc = nc  # attach number of classes to model
+     model.hyp = hyp  # attach hyperparameters to model
+     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
+     model.names = names
+
+     # Start training
+     t0 = time.time()
+     nb = len(train_loader)  # number of batches
+     nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
+     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
+     last_opt_step = -1
+     maps = np.zeros(nc)  # mAP per class
+     results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95 for boxes and masks, plus val losses
+     scheduler.last_epoch = start_epoch - 1  # do not move
+     scaler = torch.cuda.amp.GradScaler(enabled=amp)
+     stopper, stop = EarlyStopping(patience=opt.patience), False
+     compute_loss = ComputeLoss(model, overlap=overlap)  # init loss class
+     # callbacks.run('on_train_start')
+     LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
+                 f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
+                 f"Logging results to {colorstr('bold', save_dir)}\n"
+                 f'Starting training for {epochs} epochs...')
+     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
+         # callbacks.run('on_train_epoch_start')
+         model.train()
+
+         # Update image weights (optional, single-GPU only)
+         if opt.image_weights:
+             cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
+             iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
+             dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
+         if epoch == (epochs - opt.close_mosaic):
+             LOGGER.info("Closing dataloader mosaic")
+             dataset.mosaic = False
+
+         # Update mosaic border (optional)
+         # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
+         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
+
+         mloss = torch.zeros(6, device=device)  # mean losses
+         if RANK != -1:
+             train_loader.sampler.set_epoch(epoch)
+         pbar = enumerate(train_loader)
+         LOGGER.info(('\n' + '%11s' * 10) %
+                     ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss', 'fcl_loss', 'dic_loss', 'Instances', 'Size'))
+         if RANK in {-1, 0}:
+             pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
+         optimizer.zero_grad()
+         for i, (imgs, targets, paths, _, masks, semasks) in pbar:  # batch ------------------------------------------------
+             # callbacks.run('on_train_batch_start')
+             # print(imgs.shape)
+             # print(semasks.shape)
+             # print(masks.shape)
+             ni = i + nb * epoch  # number integrated batches (since train start)
+             imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0
+
+             # Warmup
+             if ni <= nw:
+                 xi = [0, nw]  # x interp
+                 # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
+                 accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
+                 for j, x in enumerate(optimizer.param_groups):
+                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                     x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
+                     if 'momentum' in x:
+                         x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
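+             # During warmup the bias group's LR decays from warmup_bias_lr toward the
+             # scheduled LR while the other groups ramp up from 0, and `accumulate`
+             # grows toward nbs / batch_size so the effective batch size approaches nbs.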
+
+             # Multi-scale
+             if opt.multi_scale:
+                 sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs  # size (randrange needs ints)
+                 sf = sz / max(imgs.shape[2:])  # scale factor
+                 if sf != 1:
+                     ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
+                     imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
+
+             # Forward
+             with torch.cuda.amp.autocast(amp):
+                 pred = model(imgs)  # forward
+                 loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float(),
+                                                 semasks=semasks.to(device).float())
+                 if RANK != -1:
+                     loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
+                 if opt.quad:
+                     loss *= 4.
+
+             # Backward
+             torch.use_deterministic_algorithms(False)
+             scaler.scale(loss).backward()
+
+             # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+             if ni - last_opt_step >= accumulate:
+                 scaler.unscale_(optimizer)  # unscale gradients
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
+                 scaler.step(optimizer)  # optimizer.step
+                 scaler.update()
+                 optimizer.zero_grad()
+                 if ema:
+                     ema.update(model)
+                 last_opt_step = ni
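+             # Note: gradients accumulate across `accumulate` batches before each
+             # optimizer step, emulating the nominal total batch size on small GPUs.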
+
+             # Log
+             if RANK in {-1, 0}:
+                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
+                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
+                 pbar.set_description(('%11s' * 2 + '%11.4g' * 8) %
+                                      (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
+                 # callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
+                 # if callbacks.stop_training:
+                 #     return
+
+                 # Mosaic plots
+                 if plots:
+                     if ni < 10:
+                         plot_images_and_masks(imgs, targets, masks, semasks, paths, save_dir / f"train_batch{ni}.jpg")
+                     if ni == 10:
+                         files = sorted(save_dir.glob('train*.jpg'))
+                         logger.log_images(files, "Mosaics", epoch)
+             # end batch ------------------------------------------------------------------------------------------------
+
+         # Scheduler
+         lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
+         scheduler.step()
+
+         if RANK in {-1, 0}:
+             # mAP
+             # callbacks.run('on_train_epoch_end', epoch=epoch)
+             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
+             final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
+             if not noval or final_epoch:  # Calculate mAP
+                 if (opt.save_period > 0 and epoch % opt.save_period == 0) or (epoch > (epochs - 2 * opt.close_mosaic)):
+                     results, maps, _ = validate.run(data_dict,
+                                                     batch_size=batch_size // WORLD_SIZE * 2,
+                                                     imgsz=imgsz,
+                                                     half=amp,
+                                                     model=ema.ema,
+                                                     single_cls=single_cls,
+                                                     dataloader=val_loader,
+                                                     save_dir=save_dir,
+                                                     plots=False,
+                                                     callbacks=callbacks,
+                                                     compute_loss=compute_loss,
+                                                     mask_downsample_ratio=mask_ratio,
+                                                     overlap=overlap)
+
+             # Update best mAP
+             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
+             stop = stopper(epoch=epoch, fitness=fi)  # early stop check
+             if fi > best_fitness:
+                 best_fitness = fi
+             log_vals = list(mloss) + list(results) + lr
+             # callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
+             # Log val metrics and media
+             metrics_dict = dict(zip(KEYS, log_vals))
+             logger.log_metrics(metrics_dict, epoch)
+
+             # Save model
+             if (not nosave) or (final_epoch and not evolve):  # if save
+                 ckpt = {
+                     'epoch': epoch,
+                     'best_fitness': best_fitness,
+                     'model': deepcopy(de_parallel(model)).half(),
+                     'ema': deepcopy(ema.ema).half(),
+                     'updates': ema.updates,
+                     'optimizer': optimizer.state_dict(),
+                     'opt': vars(opt),
+                     'git': GIT_INFO,  # {remote, branch, commit} if a git repo
+                     'date': datetime.now().isoformat()}
+
+                 # Save last, best and delete
+                 torch.save(ckpt, last)
+                 if best_fitness == fi:
+                     torch.save(ckpt, best)
+                 if opt.save_period > 0 and epoch % opt.save_period == 0:
+                     torch.save(ckpt, w / f'epoch{epoch}.pt')
+                     logger.log_model(w / f'epoch{epoch}.pt')
+                 del ckpt
+                 # callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
+
+         # EarlyStopping
+         if RANK != -1:  # if DDP training
+             broadcast_list = [stop if RANK == 0 else None]
+             dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+             if RANK != 0:
+                 stop = broadcast_list[0]
+         if stop:
+             break  # must break all DDP ranks
+
+         # end epoch ----------------------------------------------------------------------------------------------------
+     # end training -----------------------------------------------------------------------------------------------------
+     if RANK in {-1, 0}:
+         LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
+         for f in last, best:
+             if f.exists():
+                 strip_optimizer(f)  # strip optimizers
+                 if f is best:
+                     LOGGER.info(f'\nValidating {f}...')
+                     results, _, _ = validate.run(
+                         data_dict,
+                         batch_size=batch_size // WORLD_SIZE * 2,
+                         imgsz=imgsz,
+                         model=attempt_load(f, device).half(),
+                         iou_thres=0.65 if is_coco else 0.60,  # best pycocotools at iou 0.65
+                         single_cls=single_cls,
+                         dataloader=val_loader,
+                         save_dir=save_dir,
+                         save_json=is_coco,
+                         verbose=True,
+                         plots=plots,
+                         callbacks=callbacks,
+                         compute_loss=compute_loss,
+                         mask_downsample_ratio=mask_ratio,
+                         overlap=overlap)  # val best model with plots
+                     if is_coco:
+                         # callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
+                         metrics_dict = dict(zip(KEYS, list(mloss) + list(results) + lr))
+                         logger.log_metrics(metrics_dict, epoch)
+
+         # callbacks.run('on_train_end', last, best, epoch, results)
+         # on train end callback using genericLogger
+         logger.log_metrics(dict(zip(KEYS[6:22], results)), epochs)
+         if not opt.evolve:
+             logger.log_model(best, epoch)
+         if plots:
+             plot_results_with_masks(file=save_dir / 'results.csv')  # save results.png
+             files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
+             files = [(save_dir / f) for f in files if (save_dir / f).exists()]  # filter
+             LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
+             logger.log_images(files, "Results", epoch + 1)
+             logger.log_images(sorted(save_dir.glob('val*.jpg')), "Validation", epoch + 1)
+     torch.cuda.empty_cache()
+     return results
+
+
+ def parse_opt(known=False):
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--weights', type=str, default=ROOT / 'yolo-pan.pt', help='initial weights path')
+     parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
+     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-seg.yaml', help='dataset.yaml path')
+     parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
+     parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
+     parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
+     parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
+     parser.add_argument('--rect', action='store_true', help='rectangular training')
+     parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
+     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
+     parser.add_argument('--noval', action='store_true', help='only validate final epoch')
+     parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
+     parser.add_argument('--noplots', action='store_true', help='save no plot files')
+     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
+     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
+     parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
+     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
+     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
+     parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
+     parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW', 'LION'], default='SGD', help='optimizer')
+     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
+     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
+     parser.add_argument('--project', default=ROOT / 'runs/train-pan', help='save to project/name')
+     parser.add_argument('--name', default='exp', help='save to project/name')
+     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+     parser.add_argument('--quad', action='store_true', help='quad dataloader')
+     parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
+     parser.add_argument('--flat-cos-lr', action='store_true', help='flat cosine LR scheduler')
+     parser.add_argument('--fixed-lr', action='store_true', help='fixed LR scheduler')
+     parser.add_argument('--poly-lr', action='store_true', help='polynomial LR scheduler')
+     parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
+     parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
+     parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
+     parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
+     parser.add_argument('--seed', type=int, default=0, help='Global training seed')
+     parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
+     parser.add_argument('--close-mosaic', type=int, default=0, help='Experimental')
+
+     # Instance Segmentation Args
+     parser.add_argument('--mask-ratio', type=int, default=4, help='Downsample the truth masks to save memory')
+     parser.add_argument('--no-overlap', action='store_true', help='Disable mask overlap (overlapping masks train faster at slightly lower mAP)')
+
+     return parser.parse_known_args()[0] if known else parser.parse_args()
+
+
+ def main(opt, callbacks=Callbacks()):
+     # Checks
+     if RANK in {-1, 0}:
+         print_args(vars(opt))
+         # check_git_status()
+         # check_requirements()
+
+     # Resume
+     if opt.resume and not opt.evolve:  # resume from specified or most recent last.pt
+         last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
+         opt_yaml = last.parent.parent / 'opt.yaml'  # train options yaml
+         opt_data = opt.data  # original dataset
+         if opt_yaml.is_file():
+             with open(opt_yaml, errors='ignore') as f:
+                 d = yaml.safe_load(f)
+         else:
+             d = torch.load(last, map_location='cpu')['opt']
+         opt = argparse.Namespace(**d)  # replace
+         opt.cfg, opt.weights, opt.resume = '', str(last), True  # reinstate
+         if is_url(opt_data):
+             opt.data = check_file(opt_data)  # avoid HUB resume auth timeout
+     else:
+         opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
+             check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
+         assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
+         if opt.evolve:
+             if opt.project == str(ROOT / 'runs/train'):  # if default project name, rename to runs/evolve
+                 opt.project = str(ROOT / 'runs/evolve')
+             opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
+         if opt.name == 'cfg':
+             opt.name = Path(opt.cfg).stem  # use model.yaml as name
+         opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
+
+     # DDP mode
+     device = select_device(opt.device, batch_size=opt.batch_size)
+     if LOCAL_RANK != -1:
+         msg = 'is not compatible with YOLO Multi-GPU DDP training'
+         assert not opt.image_weights, f'--image-weights {msg}'
+         assert not opt.evolve, f'--evolve {msg}'
+         assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
+         assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
+         assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
+         torch.cuda.set_device(LOCAL_RANK)
+         device = torch.device('cuda', LOCAL_RANK)
+         dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
+
+     # Train
+     if not opt.evolve:
+         train(opt.hyp, opt, device, callbacks)
+
+     # Evolve hyperparameters (optional)
+     else:
+         # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
+         meta = {
+             'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
+             'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
+             'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
+             'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
+             'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
+             'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
+             'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
+             'box': (1, 0.02, 0.2),  # box loss gain
+             'cls': (1, 0.2, 4.0),  # cls loss gain
+             'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
+             'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
+             'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
+             'iou_t': (0, 0.1, 0.7),  # IoU training threshold
+             'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
+             'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
+             'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (efficientDet default gamma=1.5)
+             'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+             'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
+             'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
+             'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
+             'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
+             'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
+             'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
+             'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+             'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
+             'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
+             'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
+             'mixup': (1, 0.0, 1.0),  # image mixup (probability)
+             'copy_paste': (1, 0.0, 1.0)}  # segment copy-paste (probability)
+
+         with open(opt.hyp, errors='ignore') as f:
+             hyp = yaml.safe_load(f)  # load hyps dict
+             if 'anchors' not in hyp:  # anchors commented in hyp.yaml
+                 hyp['anchors'] = 3
+         if opt.noautoanchor:
+             del hyp['anchors'], meta['anchors']
+         opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
+         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
+         evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
+         if opt.bucket:
+             os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}')  # download evolve.csv if exists
+
+         for _ in range(opt.evolve):  # generations to evolve
+             if evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
+                 # Select parent(s)
+                 parent = 'single'  # parent selection method: 'single' or 'weighted'
+                 x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+                 n = min(5, len(x))  # number of previous results to consider
+                 x = x[np.argsort(-fitness(x))][:n]  # top n mutations
+                 w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
+                 if parent == 'single' or len(x) == 1:
+                     # x = x[random.randint(0, n - 1)]  # random selection
+                     x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
+                 elif parent == 'weighted':
+                     x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination
+
+                 # Mutate
+                 mp, s = 0.8, 0.2  # mutation probability, sigma
+                 npr = np.random
+                 npr.seed(int(time.time()))
+                 g = np.array([meta[k][0] for k in hyp.keys()])  # gains 0-1
+                 ng = len(meta)
+                 v = np.ones(ng)
+                 while all(v == 1):  # mutate until a change occurs (prevent duplicates)
+                     v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
+                 for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
+                     hyp[k] = float(x[i + 7] * v[i])  # mutate
+
+                 # Constrain to limits
+                 for k, v in meta.items():
+                     hyp[k] = max(hyp[k], v[1])  # lower limit
+                     hyp[k] = min(hyp[k], v[2])  # upper limit
+                     hyp[k] = round(hyp[k], 5)  # significant digits
+
+             # Train mutation
+             results = train(hyp.copy(), opt, device, callbacks)
+             callbacks = Callbacks()
+             # Write mutation results
+             print_mutation(KEYS, results, hyp.copy(), save_dir, opt.bucket)
+
+         # Plot results
+         plot_evolve(evolve_csv)
+         LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
+                     f"Results saved to {colorstr('bold', save_dir)}\n"
+                     f'Usage example: $ python train.py --hyp {evolve_yaml}')
+
+
+ def run(**kwargs):
+     # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolo.pt')
+     opt = parse_opt(True)
+     for k, v in kwargs.items():
+         setattr(opt, k, v)
+     main(opt)
+     return opt
+
+
+ if __name__ == "__main__":
+     opt = parse_opt()
+     main(opt)
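The evolve branch above mutates hyperparameters multiplicatively, with a per-key gain deciding how aggressively each one moves. A minimal standalone sketch of that mutation step (toy `meta`/`hyp` values assumed; parent selection from evolve.csv and the `x[i + 7]` column offset are omitted):

```python
import numpy as np

# Toy subset of the (gain, low, high) metadata and current hyps (assumed values)
meta = {'lr0': (1, 1e-5, 1e-1), 'momentum': (0.3, 0.6, 0.98), 'scale': (1, 0.0, 0.9)}
hyp = {'lr0': 0.01, 'momentum': 0.937, 'scale': 0.9}

mp, s = 0.8, 0.2  # mutation probability, sigma
g = np.array([meta[k][0] for k in hyp])  # per-hyp mutation gains
v = np.ones(len(g))
while all(v == 1):  # resample until at least one factor changes
    v = (g * (np.random.random(len(g)) < mp) * np.random.randn(len(g)) * np.random.random() * s + 1).clip(0.3, 3.0)
for i, k in enumerate(hyp):  # apply factors, then constrain to [low, high]
    hyp[k] = round(float(np.clip(hyp[k] * v[i], meta[k][1], meta[k][2])), 5)
print(hyp)
```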
yolov9/panoptic/val.py ADDED
@@ -0,0 +1,597 @@
+ import argparse
+ import json
+ import os
+ import sys
+ from multiprocessing.pool import ThreadPool
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+
+ FILE = Path(__file__).resolve()
+ ROOT = FILE.parents[1]  # YOLO root directory
+ if str(ROOT) not in sys.path:
+     sys.path.append(str(ROOT))  # add ROOT to PATH
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+ import torch.nn.functional as F
+ import torchvision.transforms as transforms
+ from pycocotools import mask as maskUtils
+ from models.common import DetectMultiBackend
+ from models.yolo import SegmentationModel
+ from utils.callbacks import Callbacks
+ from utils.coco_utils import getCocoIds, getMappingId, getMappingIndex
+ from utils.general import (LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, Profile, check_dataset, check_img_size,
+                            check_requirements, check_yaml, coco80_to_coco91_class, colorstr, increment_path,
+                            non_max_suppression, print_args, scale_boxes, xywh2xyxy, xyxy2xywh)
+ from utils.metrics import ConfusionMatrix, box_iou
+ from utils.plots import output_to_target, plot_val_study
+ from utils.panoptic.dataloaders import create_dataloader
+ from utils.panoptic.general import mask_iou, process_mask, process_mask_upsample, scale_image
+ from utils.panoptic.metrics import Metrics, ap_per_class_box_and_mask, Semantic_Metrics
+ from utils.panoptic.plots import plot_images_and_masks
+ from utils.torch_utils import de_parallel, select_device, smart_inference_mode
+
+
+ def save_one_txt(predn, save_conf, shape, file):
+     # Save one txt result
+     gn = torch.tensor(shape)[[1, 0, 1, 0]]  # normalization gain whwh
+     for *xyxy, conf, cls in predn.tolist():
+         xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
+         line = (cls, *xywh, conf) if save_conf else (cls, *xywh)  # label format
+         with open(file, 'a') as f:
+             f.write(('%g ' * len(line)).rstrip() % line + '\n')
+
+
+ def save_one_json(predn, jdict, path, class_map, pred_masks):
+     # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
+     from pycocotools.mask import encode
+
+     def single_encode(x):
+         rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
+         rle["counts"] = rle["counts"].decode("utf-8")
+         return rle
+
+     image_id = int(path.stem) if path.stem.isnumeric() else path.stem
+     box = xyxy2xywh(predn[:, :4])  # xywh
+     box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
+     pred_masks = np.transpose(pred_masks, (2, 0, 1))
+     with ThreadPool(NUM_THREADS) as pool:
+         rles = pool.map(single_encode, pred_masks)
+     for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
+         jdict.append({
+             'image_id': image_id,
+             'category_id': class_map[int(p[5])],
+             'bbox': [round(x, 3) for x in b],
+             'score': round(p[4], 5),
+             'segmentation': rles[i]})
+
+
+ def process_batch(detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
+     """
+     Return correct prediction matrix
+     Arguments:
+         detections (array[N, 6]), x1, y1, x2, y2, conf, class
+         labels (array[M, 5]), class, x1, y1, x2, y2
+     Returns:
+         correct (array[N, 10]), for 10 IoU levels
+     """
+     if masks:
+         if overlap:
+             nl = len(labels)
+             index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
+             gt_masks = gt_masks.repeat(nl, 1, 1)  # shape(1,640,640) -> (n,640,640)
+             gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
+         if gt_masks.shape[1:] != pred_masks.shape[1:]:
+             gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
+             gt_masks = gt_masks.gt_(0.5)
+         iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
+     else:  # boxes
+         iou = box_iou(labels[:, 1:], detections[:, :4])
+
+     correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
+     correct_class = labels[:, 0:1] == detections[:, 5]
+     for i in range(len(iouv)):
+         x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match
+         if x[0].shape[0]:
+             matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detect, iou]
+             if x[0].shape[0] > 1:
+                 matches = matches[matches[:, 2].argsort()[::-1]]
+                 matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
+                 # matches = matches[matches[:, 2].argsort()[::-1]]
+                 matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
+             correct[matches[:, 1].astype(int), i] = True
+     return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
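+ # Note: the matching above is greedy by IoU — candidate (label, detection) pairs are
+ # sorted by descending IoU and deduplicated on both indices, so each detection and
+ # each ground-truth label is used at most once per IoU threshold.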
+
+
+ @smart_inference_mode()
+ def run(
+         data,
+         weights=None,  # model.pt path(s)
+         batch_size=32,  # batch size
+         imgsz=640,  # inference size (pixels)
+         conf_thres=0.001,  # confidence threshold
+         iou_thres=0.6,  # NMS IoU threshold
+         max_det=300,  # maximum detections per image
+         task='val',  # train, val, test, speed or study
+         device='',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
+         workers=8,  # max dataloader workers (per RANK in DDP mode)
+         single_cls=False,  # treat as single-class dataset
+         augment=False,  # augmented inference
+         verbose=False,  # verbose output
+         save_txt=False,  # save results to *.txt
+         save_hybrid=False,  # save label+prediction hybrid results to *.txt
+         save_conf=False,  # save confidences in --save-txt labels
+         save_json=False,  # save a COCO-JSON results file
+         project=ROOT / 'runs/val-pan',  # save to project/name
+         name='exp',  # save to project/name
+         exist_ok=False,  # existing project/name ok, do not increment
+         half=True,  # use FP16 half-precision inference
+         dnn=False,  # use OpenCV DNN for ONNX inference
+         model=None,
+         dataloader=None,
+         save_dir=Path(''),
+         plots=True,
+         overlap=False,
+         mask_downsample_ratio=1,
+         compute_loss=None,
+         callbacks=Callbacks(),
+ ):
+     if save_json:
+         check_requirements(['pycocotools'])
+         process = process_mask_upsample  # more accurate
+     else:
+         process = process_mask  # faster
+
+     # Initialize/load model and set device
+     training = model is not None
+     if training:  # called by train.py
+         device, pt, jit, engine = next(model.parameters()).device, True, False, False  # get model device, PyTorch model
+         half &= device.type != 'cpu'  # half precision only supported on CUDA
+         model.half() if half else model.float()
+         nm = de_parallel(model).model[-1].nm  # number of masks
+     else:  # called directly
+         device = select_device(device, batch_size=batch_size)
+
+         # Directories
+         save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
+         (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
+
+         # Load model
+         model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
+         stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
+         imgsz = check_img_size(imgsz, s=stride)  # check image size
+         half = model.fp16  # FP16 supported on limited backends with CUDA
+         nm = de_parallel(model).model.model[-1].nm if isinstance(model, SegmentationModel) else 32  # number of masks
+         if engine:
+             batch_size = model.batch_size
+         else:
+             device = model.device
+             if not (pt or jit):
+                 batch_size = 1  # export.py models default to batch-size 1
+                 LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+
+         # Data
+         data = check_dataset(data)  # check
+
+     # Configure
+     model.eval()
+     cuda = device.type != 'cpu'
+     # is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'coco{os.sep}val2017.txt')  # COCO dataset
+     is_coco = isinstance(data.get('val'), str) and data['val'].endswith('val2017.txt')  # COCO dataset
+     nc = 1 if single_cls else int(data['nc'])  # number of classes
+     stuff_names = data.get('stuff_names', [])  # names of stuff classes
+     stuff_nc = len(stuff_names)  # number of stuff classes
+     iouv = torch.linspace(0.5, 0.95, 10, device=device)  # iou vector for mAP@0.5:0.95
+     niou = iouv.numel()
+
+     # Semantic Segmentation
+     img_id_list = []
+
+     # Dataloader
+     if not training:
+         if pt and not single_cls:  # check --weights are trained on --data
+             ncm = model.model.nc
+             assert ncm == nc, f'{weights} ({ncm} classes) trained on different --data than what you passed ({nc} ' \
+                               f'classes). Pass correct combination of --weights and --data that are trained together.'
+         model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz))  # warmup
+         pad, rect = (0.0, False) if task == 'speed' else (0.5, pt)  # square inference for benchmarks
+         task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
+         dataloader = create_dataloader(data[task],
+                                        imgsz,
+                                        batch_size,
+                                        stride,
+                                        single_cls,
+                                        pad=pad,
+                                        rect=rect,
+                                        workers=workers,
+                                        prefix=colorstr(f'{task}: '),
+                                        overlap_mask=overlap,
+                                        mask_downsample_ratio=mask_downsample_ratio)[0]
+
+     seen = 0
+     confusion_matrix = ConfusionMatrix(nc=nc)
+     names = model.names if hasattr(model, 'names') else model.module.names  # get class names
+     if isinstance(names, (list, tuple)):  # old format
+         names = dict(enumerate(names))
+     class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
+     s = ('%22s' + '%11s' * 12) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", "R",
+                                   "mAP50", "mAP50-95)", 'S(MIoU', 'FWIoU)')
+     dt = Profile(), Profile(), Profile()
+     metrics = Metrics()
+     semantic_metrics = Semantic_Metrics(nc=(nc + stuff_nc), device=device)
+     loss = torch.zeros(6, device=device)
+     jdict, stats = [], []
+     semantic_jdict = []
+     # callbacks.run('on_val_start')
+     pbar = tqdm(dataloader, desc=s, bar_format=TQDM_BAR_FORMAT)  # progress bar
+     for batch_i, (im, targets, paths, shapes, masks, semasks) in enumerate(pbar):
+         # callbacks.run('on_val_batch_start')
+         with dt[0]:
+             if cuda:
+                 im = im.to(device, non_blocking=True)
+                 targets = targets.to(device)
+                 masks = masks.to(device)
+                 semasks = semasks.to(device)
+             masks = masks.float()
+             semasks = semasks.float()
+             im = im.half() if half else im.float()  # uint8 to fp16/32
+             im /= 255  # 0 - 255 to 0.0 - 1.0
+             nb, _, height, width = im.shape  # batch size, channels, height, width
+
+         # Inference
+         with dt[1]:
+             preds, train_out = model(im)  # if compute_loss else (*model(im, augment=augment)[:2], None)
+             # train_out, preds, protos = p if len(p) == 3 else p[1]
+             # preds = p
+             # train_out = p[1][0] if len(p[1]) == 3 else p[0]
+             # protos = train_out[-1]
+             # print(preds.shape)
+             # print(train_out[0].shape)
+             # print(train_out[1].shape)
+             # print(train_out[2].shape)
+             _, pred_masks, protos, psemasks = train_out
+
+         # Loss
+         if compute_loss:
+             loss += compute_loss(train_out, targets, masks, semasks=semasks)[1]  # box, obj, cls
+
+         # NMS
+         targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)  # to pixels
+         lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
+         with dt[2]:
+             preds = non_max_suppression(preds,
+                                         conf_thres,
+                                         iou_thres,
+                                         labels=lb,
+                                         multi_label=True,
+                                         agnostic=single_cls,
+                                         max_det=max_det,
+                                         nm=nm)
+
+         # Metrics
+         plot_masks = []  # masks for plotting
+         plot_semasks = []  # semantic masks for plotting
+
+         if training:
+             semantic_metrics.update(psemasks, semasks)
+         else:
+             _, _, smh, smw = semasks.shape
+             semantic_metrics.update(torch.nn.functional.interpolate(psemasks, size=(smh, smw), mode='bilinear', align_corners=False), semasks)
+
+         if plots and batch_i < 3:
+             plot_semasks.append(psemasks.clone().detach().cpu())
+
+         for si, (pred, proto, psemask) in enumerate(zip(preds, protos, psemasks)):
+             labels = targets[targets[:, 0] == si, 1:]
+             nl, npr = labels.shape[0], pred.shape[0]  # number of labels, predictions
+             path, shape = Path(paths[si]), shapes[si][0]
+             image_id = path.stem
+             img_id_list.append(image_id)
+             correct_masks = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # init
+             correct_bboxes = torch.zeros(npr, niou, dtype=torch.bool, device=device)  # init
+             seen += 1
+
+             if npr == 0:
+                 if nl:
+                     stats.append((correct_masks, correct_bboxes, *torch.zeros((2, 0), device=device), labels[:, 0]))
+                     if plots:
+                         confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
+             else:
+                 # Masks
+                 midx = [si] if overlap else targets[:, 0] == si
+                 gt_masks = masks[midx]
+                 pred_masks = process(proto, pred[:, 6:], pred[:, :4], shape=im[si].shape[1:])
+
+                 # Predictions
+                 if single_cls:
+                     pred[:, 5] = 0
+                 predn = pred.clone()
+                 scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
+
+                 # Evaluate
+                 if nl:
+                     tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
+                     scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
+                     labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
+                     correct_bboxes = process_batch(predn, labelsn, iouv)
+                     correct_masks = process_batch(predn, labelsn, iouv, pred_masks, gt_masks, overlap=overlap, masks=True)
+                     if plots:
+                         confusion_matrix.process_batch(predn, labelsn)
+                 stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0]))  # (conf, pcls, tcls)
+
+                 pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
+                 if plots and batch_i < 3:
+                     plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot
+
+                 # Save/log
+                 if save_txt:
+                     save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
+                 if save_json:
+                     pred_masks = scale_image(im[si].shape[1:],
+                                              pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1])
+                     save_one_json(predn, jdict, path, class_map, pred_masks)  # append to COCO-JSON dictionary
+                 # callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
+
+             # Semantic Segmentation
+             h0, w0 = shape
+
+             # resize
+             _, mask_h, mask_w = psemask.shape
+             h_ratio = mask_h / h0
+             w_ratio = mask_w / w0
+
+             if h_ratio == w_ratio:
+                 psemask = torch.nn.functional.interpolate(psemask[None, :], size=(h0, w0), mode='bilinear', align_corners=False)
+             else:
+                 transform = transforms.CenterCrop((h0, w0))
+
+                 if (1 != h_ratio) and (1 != w_ratio):
+                     h_new = h0 if (h_ratio < w_ratio) else int(mask_h / w_ratio)
+                     w_new = w0 if (h_ratio > w_ratio) else int(mask_w / h_ratio)
+                     psemask = torch.nn.functional.interpolate(psemask[None, :], size=(h_new, w_new), mode='bilinear', align_corners=False)
+
+                 psemask = transform(psemask)
+
+             psemask = torch.squeeze(psemask)
+
+             snc, h, w = psemask.shape  # semantic channels, height, width (renamed to avoid shadowing nc)
+
+             semantic_mask = torch.flatten(psemask, start_dim=1).permute(1, 0)  # class x h x w -> (h x w) x class
+
+             max_idx = semantic_mask.argmax(1)
+             output_masks = torch.zeros(semantic_mask.shape).scatter(1, max_idx.cpu().unsqueeze(1), 1.0)  # one hot: (h x w) x class
+             output_masks = torch.reshape(output_masks.permute(1, 0), (snc, h, w))  # (h x w) x class -> class x h x w
+             psemask = output_masks.to(device=device)
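+             # The scatter above hardens per-pixel class scores into a one-hot mask:
+             # argmax picks the winning class per pixel and scatter writes 1.0 at that index.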
+
+             # TODO: check is_coco
+             instances_ids = getCocoIds(name='instances')
+             stuff_mask = torch.zeros((h, w), device=device)
+             check_semantic_mask = False
+             for idx, pred_semantic_mask in enumerate(psemask):
+                 category_id = int(getMappingId(idx))
+                 if 183 == category_id:
+                     # set all non-stuff pixels to other
+                     pred_semantic_mask = (torch.logical_xor(stuff_mask, torch.ones((h, w), device=device))).int()
+
+                 # ignore all-zero classes and the unlabeled class
+                 if (0 >= torch.max(pred_semantic_mask)) or (0 >= category_id):
+                     continue
+
+                 if category_id not in instances_ids:
+                     # record all stuff masks
+                     stuff_mask = torch.logical_or(stuff_mask, pred_semantic_mask)
+
+                     rle = maskUtils.encode(np.asfortranarray(pred_semantic_mask.cpu(), dtype=np.uint8))
+                     rle['counts'] = rle['counts'].decode('utf-8')
+
+                     temp_d = {
+                         'image_id': int(image_id) if image_id.isnumeric() else image_id,
+                         'category_id': category_id,
+                         'segmentation': rle,
+                         'score': 1
+                     }
+
+                     semantic_jdict.append(temp_d)
+                     check_semantic_mask = True
+
+             if not check_semantic_mask:
+                 # append an 'other' mask for evaluation if the image has no mask at all
+                 other_mask = (torch.ones((h, w), device=device)).int()
+
+                 rle = maskUtils.encode(np.asfortranarray(other_mask.cpu(), dtype=np.uint8))
+                 rle['counts'] = rle['counts'].decode('utf-8')
+
+                 temp_d = {
+                     'image_id': int(image_id) if image_id.isnumeric() else image_id,
+                     'category_id': 183,
+                     'segmentation': rle,
+                     'score': 1
+                 }
+
+                 semantic_jdict.append(temp_d)
+
+         # Plot images
+         if plots and batch_i < 3:
+             if len(plot_masks):
+                 plot_masks = torch.cat(plot_masks, dim=0)
+             if len(plot_semasks):
+                 plot_semasks = torch.cat(plot_semasks, dim=0)
+             plot_images_and_masks(im, targets, masks, semasks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
+             plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, plot_semasks, paths,
+                                   save_dir / f'val_batch{batch_i}_pred.jpg', names)  # pred
+
+         # callbacks.run('on_val_batch_end')
+
+     # Compute metrics
+     stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
+     if len(stats) and stats[0].any():
+         results = ap_per_class_box_and_mask(*stats, plot=plots, save_dir=save_dir, names=names)
+         metrics.update(results)
+     nt = np.bincount(stats[4].astype(int), minlength=nc)  # number of targets per class
+
+     # Print results
+     pf = '%22s' + '%11i' * 2 + '%11.3g' * 10  # print format
+     LOGGER.info(pf % ("all", seen, nt.sum(), *metrics.mean_results(), *semantic_metrics.results()))
+     if nt.sum() == 0:
+         LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels')
+
+     # Print results per class
+     if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
+         for i, c in enumerate(metrics.ap_class_index):
+             LOGGER.info(pf % (names[c], seen, nt[c], *metrics.class_result(i), *semantic_metrics.results()))
+
+     # Print speeds
+     t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
+     if not training:
+         shape = (batch_size, 3, imgsz, imgsz)
+         LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
+
+     # Plots
+     if plots:
+         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
+         # callbacks.run('on_val_end')
+
+     mp_bbox, mr_bbox, map50_bbox, map_bbox, mp_mask, mr_mask, map50_mask, map_mask = metrics.mean_results()
+     miou_sem, fwiou_sem = semantic_metrics.results()
+     semantic_metrics.reset()
+
+     # Save JSON
+     if save_json and len(jdict):
+         w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
+         anno_path = Path(data.get('path', '../coco'))
+         anno_json = str(anno_path / 'annotations/instances_val2017.json')  # annotations json
+         pred_json = str(save_dir / f"{w}_predictions.json")  # predictions json
+         LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
+         with open(pred_json, 'w') as f:
+             json.dump(jdict, f)
+
+         semantic_anno_json = str(anno_path / 'annotations/stuff_val2017.json')  # annotations json
+         semantic_pred_json = str(save_dir / f"{w}_predictions_stuff.json")  # predictions json
+         LOGGER.info(f'\nsaving {semantic_pred_json}...')
+         with open(semantic_pred_json, 'w') as f:
+             json.dump(semantic_jdict, f)
+
+         try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+             from pycocotools.coco import COCO
+             from pycocotools.cocoeval import COCOeval
+
+             anno = COCO(anno_json)  # init annotations api
+             pred = anno.loadRes(pred_json)  # init predictions api
+             results = []
+             for eval in COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm'):
+                 if is_coco:
+                     eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.im_files]  # img IDs to evaluate
+                 eval.evaluate()
+                 eval.accumulate()
+                 eval.summarize()
+                 results.extend(eval.stats[:2])  # update results (mAP@0.5:0.95, mAP@0.5)
+             map_bbox, map50_bbox, map_mask, map50_mask = results
+
+             # Semantic Segmentation
+             from utils.stuff_seg.cocostuffeval import COCOStuffeval
+
+             LOGGER.info(f'\nEvaluating pycocotools stuff... ')
+             imgIds = [int(x) for x in img_id_list]
+
+             stuffGt = COCO(semantic_anno_json)  # initialize COCO ground truth api
+             stuffDt = stuffGt.loadRes(semantic_pred_json)  # initialize COCO pred api
+
+             cocoStuffEval = COCOStuffeval(stuffGt, stuffDt)
+             cocoStuffEval.params.imgIds = imgIds  # image IDs to evaluate
+             cocoStuffEval.evaluate()
+             stats, statsClass = cocoStuffEval.summarize()
+             stuffIds = getCocoIds(name='stuff')
+             title = ' {:<5} | {:^6} | {:^6} '.format('class', 'iou', 'macc') if (0 >= len(stuff_names)) else \
+                 ' {:<5} | {:<20} | {:^6} | {:^6} '.format('class', 'class name', 'iou', 'macc')
+             print(title)
+             for idx, (iou, macc) in enumerate(zip(statsClass['ious'], statsClass['maccs'])):
+                 id = (idx + 1)
+                 if id not in stuffIds:
+                     continue
+                 content = ' {:<5} | {:0.4f} | {:0.4f} '.format(str(id), iou, macc) if (0 >= len(stuff_names)) else \
+                     ' {:<5} | {:<20} | {:0.4f} | {:0.4f} '.format(str(id), str(stuff_names[getMappingIndex(id, name='stuff')]), iou, macc)
+                 print(content)
+
+         except Exception as e:
+             LOGGER.info(f'pycocotools unable to run: {e}')
+
+     # Return results
+     model.float()  # for training
+     if not training:
+         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
+     final_metric = mp_bbox, mr_bbox, map50_bbox, map_bbox, mp_mask, mr_mask, map50_mask, map_mask, miou_sem, fwiou_sem
+     return (*final_metric, *(loss.cpu() / len(dataloader)).tolist()), metrics.get_maps(nc), t
+
+
+ def parse_opt():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-pan.yaml', help='dataset.yaml path')
+     parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo-pan.pt', help='model path(s)')
+     parser.add_argument('--batch-size', type=int, default=32, help='batch size')
+     parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
+     parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
+     parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold')
+     parser.add_argument('--max-det', type=int, default=300, help='maximum detections per image')
+     parser.add_argument('--task', default='val', help='train, val, test, speed or study')
+     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
+     parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
+     parser.add_argument('--augment', action='store_true', help='augmented inference')
+     parser.add_argument('--verbose', action='store_true', help='report mAP by class')
+     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
+     parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
+     parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+     parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file')
+     parser.add_argument('--project', default=ROOT / 'runs/val-pan', help='save results to project/name')
+     parser.add_argument('--name', default='exp', help='save to project/name')
+     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+     parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+     opt = parser.parse_args()
+     opt.data = check_yaml(opt.data)  # check YAML
+     # opt.save_json |= opt.data.endswith('coco.yaml')
+     opt.save_txt |= opt.save_hybrid
+     print_args(vars(opt))
+     return opt
+
+
+ def main(opt):
+     # check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
+
+     if opt.task in ('train', 'val', 'test'):  # run normally
+         if opt.conf_thres > 0.001:  # https://github.com/
+             LOGGER.warning(f'WARNING ⚠️ confidence threshold {opt.conf_thres} > 0.001 produces invalid results')
+         if opt.save_hybrid:
+             LOGGER.warning('WARNING ⚠️ --save-hybrid returns high mAP from hybrid labels, not from predictions alone')
+         run(**vars(opt))
+
+     else:
+         weights = opt.weights if isinstance(opt.weights, list) else [opt.weights]
+         opt.half = torch.cuda.is_available() and opt.device != 'cpu'  # FP16 for fastest results
+         if opt.task == 'speed':  # speed benchmarks
+             # python val.py --task speed --data coco.yaml --batch 1 --weights yolo.pt...
+             opt.conf_thres, opt.iou_thres, opt.save_json = 0.25, 0.45, False
+             for opt.weights in weights:
+                 run(**vars(opt), plots=False)
+
+         elif opt.task == 'study':  # speed vs mAP benchmarks
+             # python val.py --task study --data coco.yaml --iou 0.7 --weights yolo.pt...
+             for opt.weights in weights:
+                 f = f'study_{Path(opt.data).stem}_{Path(opt.weights).stem}.txt'  # filename to save to
+                 x, y = list(range(256, 1536 + 128, 128)), []  # x axis (image sizes), y axis
+                 for opt.imgsz in x:  # img-size
+                     LOGGER.info(f'\nRunning {f} --imgsz {opt.imgsz}...')
+                     r, _, t = run(**vars(opt), plots=False)
+                     y.append(r + t)  # results and times
+                 np.savetxt(f, y, fmt='%10.4g')  # save
+             os.system('zip -r study.zip study_*.txt')
+             plot_val_study(x=x)  # plot
+
+
+ if __name__ == "__main__":
+     opt = parse_opt()
+     main(opt)
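The validator can also be driven programmatically via run() instead of the CLI. A minimal sketch, assuming the yolov9 repo root is on sys.path, panoptic is importable as a package, and a local yolo-pan.pt checkpoint plus the coco128-pan dataset exist (all hypothetical paths):

```python
import sys
sys.path.append('yolov9')  # assumption: repo checked out in ./yolov9

from panoptic import val  # panoptic/val.py shown above

# keyword arguments mirror the parse_opt() defaults above
results, maps, times = val.run(data='yolov9/data/coco128-pan.yaml',
                               weights='yolo-pan.pt',  # assumed local checkpoint
                               batch_size=16,
                               imgsz=640,
                               half=False)
print(results)  # box P/R/mAPs, mask P/R/mAPs, MIoU, FWIoU, then mean val losses
```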
yolov9/requirements.txt ADDED
@@ -0,0 +1,47 @@
+ # requirements
+ # Usage: pip install -r requirements.txt
+
+ # Base ------------------------------------------------------------------------
+ gitpython
+ ipython
+ matplotlib>=3.2.2
+ numpy>=1.18.5
+ opencv-python>=4.1.1
+ Pillow>=7.1.2
+ psutil
+ PyYAML>=5.3.1
+ requests>=2.23.0
+ scipy>=1.4.1
+ thop>=0.1.1
+ torch>=1.7.0
+ torchvision>=0.8.1
+ tqdm>=4.64.0
+ # protobuf<=3.20.1
+
+ # Logging ---------------------------------------------------------------------
+ tensorboard>=2.4.1
+ # clearml>=1.2.0
+ # comet
+
+ # Plotting --------------------------------------------------------------------
+ pandas>=1.1.4
+ seaborn>=0.11.0
+
+ # Export ----------------------------------------------------------------------
+ # coremltools>=6.0
+ # onnx>=1.9.0
+ # onnx-simplifier>=0.4.1
+ # nvidia-pyindex
+ # nvidia-tensorrt
+ # scikit-learn<=1.1.2
+ # tensorflow>=2.4.1
+ # tensorflowjs>=3.9.0
+ # openvino-dev
+
+ # Deploy ----------------------------------------------------------------------
+ # tritonclient[all]~=2.24.0
+
+ # Extras ----------------------------------------------------------------------
+ # mss
+ albumentations>=1.0.3
+ pycocotools>=2.0
yolov9/scripts/get_coco.sh ADDED
@@ -0,0 +1,22 @@
+ #!/bin/bash
+ # COCO 2017 dataset http://cocodataset.org
+ # Download command: bash ./scripts/get_coco.sh
+
+ # Download/unzip labels
+ d='./'  # unzip directory
+ url=https://github.com/ultralytics/yolov5/releases/download/v1.0/
+ f='coco2017labels-segments.zip'  # or 'coco2017labels.zip', 68 MB
+ echo 'Downloading' $url$f ' ...'
+ curl -L $url$f -o $f && unzip -q $f -d $d && rm $f &  # download, unzip, remove in background
+
+ # Download/unzip images
+ d='./coco/images'  # unzip directory
+ url=http://images.cocodataset.org/zips/
+ f1='train2017.zip'  # 19G, 118k images
+ f2='val2017.zip'  # 1G, 5k images
+ f3='test2017.zip'  # 7G, 41k images (optional)
+ for f in $f1 $f2 $f3; do
+   echo 'Downloading' $url$f '...'
+   curl -L $url$f -o $f && unzip -q $f -d $d && rm $f &  # download, unzip, remove in background
+ done
+ wait  # finish background tasks
yolov9/segment/predict.py ADDED
@@ -0,0 +1,246 @@
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ FILE = Path(__file__).resolve()
10
+ ROOT = FILE.parents[1] # YOLO root directory
11
+ if str(ROOT) not in sys.path:
12
+ sys.path.append(str(ROOT)) # add ROOT to PATH
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import DetectMultiBackend
16
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
17
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
18
+ increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
19
+ strip_optimizer, xyxy2xywh)
20
+ from utils.plots import Annotator, colors, save_one_box
21
+ from utils.segment.general import masks2segments, process_mask
22
+ from utils.torch_utils import select_device, smart_inference_mode
23
+
24
+
25
+ @smart_inference_mode()
26
+ def run(
27
+ weights=ROOT / 'yolo-seg.pt', # model.pt path(s)
28
+ source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
29
+ data=ROOT / 'data/coco.yaml', # dataset.yaml path
30
+ imgsz=(640, 640), # inference size (height, width)
31
+ conf_thres=0.25, # confidence threshold
32
+ iou_thres=0.45, # NMS IOU threshold
33
+ max_det=1000, # maximum detections per image
34
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
35
+ view_img=False, # show results
36
+ save_txt=False, # save results to *.txt
37
+ save_conf=False, # save confidences in --save-txt labels
38
+ save_crop=False, # save cropped prediction boxes
39
+ nosave=False, # do not save images/videos
40
+ classes=None, # filter by class: --class 0, or --class 0 2 3
41
+ agnostic_nms=False, # class-agnostic NMS
42
+ augment=False, # augmented inference
43
+ visualize=False, # visualize features
44
+ update=False, # update all models
45
+ project=ROOT / 'runs/predict-seg', # save results to project/name
46
+ name='exp', # save results to project/name
47
+ exist_ok=False, # existing project/name ok, do not increment
48
+ line_thickness=3, # bounding box thickness (pixels)
49
+ hide_labels=False, # hide labels
50
+ hide_conf=False, # hide confidences
51
+ half=False, # use FP16 half-precision inference
52
+ dnn=False, # use OpenCV DNN for ONNX inference
53
+ vid_stride=1, # video frame-rate stride
54
+ retina_masks=False,
55
+ ):
56
+ source = str(source)
57
+ save_img = not nosave and not source.endswith('.txt') # save inference images
58
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
59
+ is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
60
+ webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
61
+ screenshot = source.lower().startswith('screen')
62
+ if is_url and is_file:
63
+ source = check_file(source) # download
64
+
65
+ # Directories
66
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
67
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
68
+
69
+ # Load model
70
+ device = select_device(device)
71
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
72
+ stride, names, pt = model.stride, model.names, model.pt
73
+ imgsz = check_img_size(imgsz, s=stride) # check image size
74
+
75
+ # Dataloader
76
+ bs = 1 # batch_size
77
+ if webcam:
78
+ view_img = check_imshow(warn=True)
79
+ dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
80
+ bs = len(dataset)
81
+ elif screenshot:
82
+ dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
83
+ else:
84
+ dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
85
+ vid_path, vid_writer = [None] * bs, [None] * bs
86
+
87
+ # Run inference
88
+ model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
89
+ seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
90
+ for path, im, im0s, vid_cap, s in dataset:
91
+ with dt[0]:
92
+ im = torch.from_numpy(im).to(model.device)
93
+ im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
94
+ im /= 255 # 0 - 255 to 0.0 - 1.0
95
+ if len(im.shape) == 3:
96
+ im = im[None] # expand for batch dim
97
+
98
+ # Inference
99
+ with dt[1]:
100
+ visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
101
+ pred, proto = model(im, augment=augment, visualize=visualize)[:2]
102
+
103
+ # NMS
104
+ with dt[2]:
105
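+ # nm=32: each prediction row carries 32 mask coefficients after the box/class columns; NMS passes them through untouched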
+ pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det, nm=32)
106
+
107
+ # Second-stage classifier (optional)
108
+ # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
109
+
110
+ # Process predictions
111
+ for i, det in enumerate(pred): # per image
112
+ seen += 1
113
+ if webcam: # batch_size >= 1
114
+ p, im0, frame = path[i], im0s[i].copy(), dataset.count
115
+ s += f'{i}: '
116
+ else:
117
+ p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
118
+
119
+ p = Path(p) # to Path
120
+ save_path = str(save_dir / p.name) # im.jpg
121
+ txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
122
+ s += '%gx%g ' % im.shape[2:] # print string
123
+ imc = im0.copy() if save_crop else im0 # for save_crop
124
+ annotator = Annotator(im0, line_width=line_thickness, example=str(names))
125
+ if len(det):
126
+ masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC
127
+ det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size
128
+
129
+ # Segments
130
+ if save_txt:
131
+ segments = reversed(masks2segments(masks))
132
+ segments = [scale_segments(im.shape[2:], x, im0.shape, normalize=True) for x in segments]
133
+
134
+ # Print results
135
+ for c in det[:, 5].unique():
136
+ n = (det[:, 5] == c).sum() # detections per class
137
+ s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
138
+
139
+ # Mask plotting
140
+ annotator.masks(masks,
141
+ colors=[colors(x, True) for x in det[:, 5]],
142
+ im_gpu=None if retina_masks else im[i])
143
+
144
+ # Write results
145
+ for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
146
+ if save_txt: # Write to file
147
+ segj = segments[j].reshape(-1) # (n,2) to (n*2)
148
+ line = (cls, *segj, conf) if save_conf else (cls, *segj) # label format
149
+ with open(f'{txt_path}.txt', 'a') as f:
150
+ f.write(('%g ' * len(line)).rstrip() % line + '\n')
151
+
152
+ if save_img or save_crop or view_img: # Add bbox to image
153
+ c = int(cls) # integer class
154
+ label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
155
+ annotator.box_label(xyxy, label, color=colors(c, True))
156
+ # annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
157
+ if save_crop:
158
+ save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
159
+
160
+ # Stream results
161
+ im0 = annotator.result()
162
+ if view_img:
163
+ if platform.system() == 'Linux' and p not in windows:
164
+ windows.append(p)
165
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
166
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
167
+ cv2.imshow(str(p), im0)
168
+ if cv2.waitKey(1) == ord('q'): # 1 millisecond
169
+ exit()
170
+
171
+ # Save results (image with detections)
172
+ if save_img:
173
+ if dataset.mode == 'image':
174
+ cv2.imwrite(save_path, im0)
175
+ else: # 'video' or 'stream'
176
+ if vid_path[i] != save_path: # new video
177
+ vid_path[i] = save_path
178
+ if isinstance(vid_writer[i], cv2.VideoWriter):
179
+ vid_writer[i].release() # release previous video writer
180
+ if vid_cap: # video
181
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
182
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
183
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
184
+ else: # stream
185
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
186
+ save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
187
+ vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
188
+ vid_writer[i].write(im0)
189
+
190
+ # Print time (inference-only)
191
+ LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
192
+
193
+ # Print results
194
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
195
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
196
+ if save_txt or save_img:
197
+ s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
198
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
199
+ if update:
200
+ strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
201
+
202
+
203
+ def parse_opt():
204
+ parser = argparse.ArgumentParser()
205
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo-seg.pt', help='model path(s)')
206
+ parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
207
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
208
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
209
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
210
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
211
+ parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
212
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
213
+ parser.add_argument('--view-img', action='store_true', help='show results')
214
+ parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
215
+ parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
216
+ parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
217
+ parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
218
+ parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
219
+ parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
220
+ parser.add_argument('--augment', action='store_true', help='augmented inference')
221
+ parser.add_argument('--visualize', action='store_true', help='visualize features')
222
+ parser.add_argument('--update', action='store_true', help='update all models')
223
+ parser.add_argument('--project', default=ROOT / 'runs/predict-seg', help='save results to project/name')
224
+ parser.add_argument('--name', default='exp', help='save results to project/name')
225
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
226
+ parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
227
+ parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
228
+ parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
229
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
230
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
231
+ parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
232
+ parser.add_argument('--retina-masks', action='store_true', help='whether to plot masks in native resolution')
233
+ opt = parser.parse_args()
234
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
235
+ print_args(vars(opt))
236
+ return opt
237
+
238
+
239
+ def main(opt):
240
+ check_requirements(exclude=('tensorboard', 'thop'))
241
+ run(**vars(opt))
242
+
243
+
244
+ if __name__ == "__main__":
245
+ opt = parse_opt()
246
+ main(opt)
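Besides the CLI, segmentation inference can be driven programmatically through `run()`. A minimal sketch, assuming the yolov9 directory is on `sys.path` and a trained `yolo-seg.pt` checkpoint exists locally (both are assumptions, not part of this commit):

from segment.predict import run

# run inference on a folder of images; annotated images and YOLO-format
# label txt files are written under runs/predict-seg/exp*
run(weights='yolo-seg.pt', source='data/images', conf_thres=0.4, save_txt=True)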
yolov9/segment/train.py ADDED
@@ -0,0 +1,646 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import random
5
+ import sys
6
+ import time
7
+ from copy import deepcopy
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.distributed as dist
14
+ import torch.nn as nn
15
+ import yaml
16
+ from torch.optim import lr_scheduler
17
+ from tqdm import tqdm
18
+
19
+ FILE = Path(__file__).resolve()
20
+ ROOT = FILE.parents[1] # YOLO root directory
21
+ if str(ROOT) not in sys.path:
22
+ sys.path.append(str(ROOT)) # add ROOT to PATH
23
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
24
+
25
+ import segment.val as validate # for end-of-epoch mAP
26
+ from models.experimental import attempt_load
27
+ from models.yolo import SegmentationModel
28
+ from utils.autoanchor import check_anchors
29
+ from utils.autobatch import check_train_batch_size
30
+ from utils.callbacks import Callbacks
31
+ from utils.downloads import attempt_download, is_url
32
+ from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info,
33
+ check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr,
34
+ get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
35
+ labels_to_image_weights, one_cycle, print_args, print_mutation, strip_optimizer, yaml_save)
36
+ from utils.loggers import GenericLogger
37
+ from utils.plots import plot_evolve, plot_labels
38
+ from utils.segment.dataloaders import create_dataloader
39
+ from utils.segment.loss_tal import ComputeLoss
40
+ from utils.segment.metrics import KEYS, fitness
41
+ from utils.segment.plots import plot_images_and_masks, plot_results_with_masks
42
+ from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
43
+ smart_resume, torch_distributed_zero_first)
44
+
45
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
46
+ RANK = int(os.getenv('RANK', -1))
47
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
48
+ GIT_INFO = None  # check_git_info()
49
+
50
+
51
+ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
52
+ save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, mask_ratio = \
53
+ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
54
+ opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze, opt.mask_ratio
55
+ # callbacks.run('on_pretrain_routine_start')
56
+
57
+ # Directories
58
+ w = save_dir / 'weights' # weights dir
59
+ (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
60
+ last, best = w / 'last.pt', w / 'best.pt'
61
+
62
+ # Hyperparameters
63
+ if isinstance(hyp, str):
64
+ with open(hyp, errors='ignore') as f:
65
+ hyp = yaml.safe_load(f) # load hyps dict
66
+ LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
67
+ opt.hyp = hyp.copy() # for saving hyps to checkpoints
68
+
69
+ # Save run settings
70
+ if not evolve:
71
+ yaml_save(save_dir / 'hyp.yaml', hyp)
72
+ yaml_save(save_dir / 'opt.yaml', vars(opt))
73
+
74
+ # Loggers
75
+ data_dict = None
76
+ if RANK in {-1, 0}:
77
+ logger = GenericLogger(opt=opt, console_logger=LOGGER)
78
+
79
+ # Config
80
+ plots = not evolve and not opt.noplots # create plots
81
+ overlap = not opt.no_overlap
82
+ cuda = device.type != 'cpu'
83
+ init_seeds(opt.seed + 1 + RANK, deterministic=True)
84
+ with torch_distributed_zero_first(LOCAL_RANK):
85
+ data_dict = data_dict or check_dataset(data) # check if None
86
+ train_path, val_path = data_dict['train'], data_dict['val']
87
+ nc = 1 if single_cls else int(data_dict['nc']) # number of classes
88
+ names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
89
+ #is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
90
+ is_coco = isinstance(val_path, str) and val_path.endswith('val2017.txt') # COCO dataset
91
+
92
+ # Model
93
+ check_suffix(weights, '.pt') # check weights
94
+ pretrained = weights.endswith('.pt')
95
+ if pretrained:
96
+ with torch_distributed_zero_first(LOCAL_RANK):
97
+ weights = attempt_download(weights) # download if not found locally
98
+ ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
99
+ model = SegmentationModel(cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)
100
+ exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
101
+ csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
102
+ csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
103
+ model.load_state_dict(csd, strict=False) # load
104
+ LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
105
+ else:
106
+ model = SegmentationModel(cfg, ch=3, nc=nc).to(device) # create
107
+ amp = check_amp(model) # check AMP
108
+
109
+ # Freeze
110
+ freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
111
+ for k, v in model.named_parameters():
112
+ #v.requires_grad = True # train all layers
113
+ # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
114
+ if any(x in k for x in freeze):
115
+ LOGGER.info(f'freezing {k}')
116
+ v.requires_grad = False
117
+
118
+ # Image size
119
+ gs = max(int(model.stride.max()), 32) # grid size (max stride)
120
+ imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
121
+
122
+ # Batch size
123
+ if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
124
+ batch_size = check_train_batch_size(model, imgsz, amp)
125
+ logger.update_params({"batch_size": batch_size})
126
+ # loggers.on_params_update({"batch_size": batch_size})
127
+
128
+ # Optimizer
129
+ nbs = 64 # nominal batch size
130
+ accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
131
+ hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
132
+ optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
133
+
134
+ # Scheduler
135
+ if opt.cos_lr:
136
+ lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
137
+ else:
138
+ lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
139
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
140
+
141
+ # EMA
142
+ ema = ModelEMA(model) if RANK in {-1, 0} else None
143
+
144
+ # Resume
145
+ best_fitness, start_epoch = 0.0, 0
146
+ if pretrained:
147
+ if resume:
148
+ best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
149
+ del ckpt, csd
150
+
151
+ # DP mode
152
+ if cuda and RANK == -1 and torch.cuda.device_count() > 1:
153
+ LOGGER.warning('WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.')
154
+ model = torch.nn.DataParallel(model)
155
+
156
+ # SyncBatchNorm
157
+ if opt.sync_bn and cuda and RANK != -1:
158
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
159
+ LOGGER.info('Using SyncBatchNorm()')
160
+
161
+ # Trainloader
162
+ train_loader, dataset = create_dataloader(
163
+ train_path,
164
+ imgsz,
165
+ batch_size // WORLD_SIZE,
166
+ gs,
167
+ single_cls,
168
+ hyp=hyp,
169
+ augment=True,
170
+ cache=None if opt.cache == 'val' else opt.cache,
171
+ rect=opt.rect,
172
+ rank=LOCAL_RANK,
173
+ workers=workers,
174
+ image_weights=opt.image_weights,
175
+ close_mosaic=opt.close_mosaic != 0,
176
+ quad=opt.quad,
177
+ prefix=colorstr('train: '),
178
+ shuffle=True,
179
+ mask_downsample_ratio=mask_ratio,
180
+ overlap_mask=overlap,
181
+ )
182
+ labels = np.concatenate(dataset.labels, 0)
183
+ mlc = int(labels[:, 0].max()) # max label class
184
+ assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
185
+
186
+ # Process 0
187
+ if RANK in {-1, 0}:
188
+ val_loader = create_dataloader(val_path,
189
+ imgsz,
190
+ batch_size // WORLD_SIZE * 2,
191
+ gs,
192
+ single_cls,
193
+ hyp=hyp,
194
+ cache=None if noval else opt.cache,
195
+ rect=True,
196
+ rank=-1,
197
+ workers=workers * 2,
198
+ pad=0.5,
199
+ mask_downsample_ratio=mask_ratio,
200
+ overlap_mask=overlap,
201
+ prefix=colorstr('val: '))[0]
202
+
203
+ if not resume:
204
+ #if not opt.noautoanchor:
205
+ # check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
206
+ model.half().float() # pre-reduce anchor precision
207
+
208
+ if plots:
209
+ plot_labels(labels, names, save_dir)
210
+ # callbacks.run('on_pretrain_routine_end', labels, names)
211
+
212
+ # DDP mode
213
+ if cuda and RANK != -1:
214
+ model = smart_DDP(model)
215
+
216
+ # Model attributes
217
+ nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
218
+ hyp['box'] *= 3 / nl # scale to layers
219
+ hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
220
+ hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
221
+ hyp['label_smoothing'] = opt.label_smoothing
222
+ model.nc = nc # attach number of classes to model
223
+ model.hyp = hyp # attach hyperparameters to model
224
+ model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
225
+ model.names = names
226
+
227
+ # Start training
228
+ t0 = time.time()
229
+ nb = len(train_loader) # number of batches
230
+ nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations)
231
+ # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
232
+ last_opt_step = -1
233
+ maps = np.zeros(nc) # mAP per class
234
+ results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # box/mask (P, R, mAP@.5, mAP@.5-.95), val_loss(box, seg, cls, dfl)
235
+ scheduler.last_epoch = start_epoch - 1 # do not move
236
+ scaler = torch.cuda.amp.GradScaler(enabled=amp)
237
+ stopper, stop = EarlyStopping(patience=opt.patience), False
238
+ compute_loss = ComputeLoss(model, overlap=overlap) # init loss class
239
+ # callbacks.run('on_train_start')
240
+ LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
241
+ f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
242
+ f"Logging results to {colorstr('bold', save_dir)}\n"
243
+ f'Starting training for {epochs} epochs...')
244
+ for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
245
+ # callbacks.run('on_train_epoch_start')
246
+ model.train()
247
+
248
+ # Update image weights (optional, single-GPU only)
249
+ if opt.image_weights:
250
+ cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
251
+ iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
252
+ dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
253
+ if epoch == (epochs - opt.close_mosaic):
254
+ LOGGER.info("Closing dataloader mosaic")
255
+ dataset.mosaic = False
256
+
257
+ # Update mosaic border (optional)
258
+ # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
259
+ # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
260
+
261
+ mloss = torch.zeros(4, device=device) # mean losses
262
+ if RANK != -1:
263
+ train_loader.sampler.set_epoch(epoch)
264
+ pbar = enumerate(train_loader)
265
+ LOGGER.info(('\n' + '%11s' * 8) %
266
+ ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss', 'Instances', 'Size'))
267
+ if RANK in {-1, 0}:
268
+ pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT) # progress bar
269
+ optimizer.zero_grad()
270
+ for i, (imgs, targets, paths, _, masks) in pbar: # batch ------------------------------------------------------
271
+ # callbacks.run('on_train_batch_start')
272
+ ni = i + nb * epoch # number integrated batches (since train start)
273
+ imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
274
+
275
+ # Warmup
276
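+ # warmup ramps accumulate from 1 to nbs/batch_size and interpolates per-group lr/momentum over the first nw iterations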
+ if ni <= nw:
277
+ xi = [0, nw] # x interp
278
+ # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
279
+ accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
280
+ for j, x in enumerate(optimizer.param_groups):
281
+ # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
282
+ x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
283
+ if 'momentum' in x:
284
+ x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
285
+
286
+ # Multi-scale
287
+ if opt.multi_scale:
288
+ sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
289
+ sf = sz / max(imgs.shape[2:]) # scale factor
290
+ if sf != 1:
291
+ ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
292
+ imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
293
+
294
+ # Forward
295
+ with torch.cuda.amp.autocast(amp):
296
+ pred = model(imgs) # forward
297
+ loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
298
+ if RANK != -1:
299
+ loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
300
+ if opt.quad:
301
+ loss *= 4.
302
+
303
+ # Backward
304
+ scaler.scale(loss).backward()
305
+
306
+ # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
307
+ if ni - last_opt_step >= accumulate:
308
+ scaler.unscale_(optimizer) # unscale gradients
309
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients
310
+ scaler.step(optimizer) # optimizer.step
311
+ scaler.update()
312
+ optimizer.zero_grad()
313
+ if ema:
314
+ ema.update(model)
315
+ last_opt_step = ni
316
+
317
+ # Log
318
+ if RANK in {-1, 0}:
319
+ mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
320
+ mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
321
+ pbar.set_description(('%11s' * 2 + '%11.4g' * 6) %
322
+ (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
323
+ # callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
324
+ # if callbacks.stop_training:
325
+ # return
326
+
327
+ # Mosaic plots
328
+ if plots:
329
+ if ni < 3:
330
+ plot_images_and_masks(imgs, targets, masks, paths, save_dir / f"train_batch{ni}.jpg")
331
+ if ni == 10:
332
+ files = sorted(save_dir.glob('train*.jpg'))
333
+ logger.log_images(files, "Mosaics", epoch)
334
+ # end batch ------------------------------------------------------------------------------------------------
335
+
336
+ # Scheduler
337
+ lr = [x['lr'] for x in optimizer.param_groups] # for loggers
338
+ scheduler.step()
339
+
340
+ if RANK in {-1, 0}:
341
+ # mAP
342
+ # callbacks.run('on_train_epoch_end', epoch=epoch)
343
+ ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
344
+ final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
345
+ if not noval or final_epoch: # Calculate mAP
346
+ results, maps, _ = validate.run(data_dict,
347
+ batch_size=batch_size // WORLD_SIZE * 2,
348
+ imgsz=imgsz,
349
+ half=amp,
350
+ model=ema.ema,
351
+ single_cls=single_cls,
352
+ dataloader=val_loader,
353
+ save_dir=save_dir,
354
+ plots=False,
355
+ callbacks=callbacks,
356
+ compute_loss=compute_loss,
357
+ mask_downsample_ratio=mask_ratio,
358
+ overlap=overlap)
359
+
360
+ # Update best mAP
361
+ fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
362
+ stop = stopper(epoch=epoch, fitness=fi) # early stop check
363
+ if fi > best_fitness:
364
+ best_fitness = fi
365
+ log_vals = list(mloss) + list(results) + lr
366
+ # callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
367
+ # Log val metrics and media
368
+ metrics_dict = dict(zip(KEYS, log_vals))
369
+ logger.log_metrics(metrics_dict, epoch)
370
+
371
+ # Save model
372
+ if (not nosave) or (final_epoch and not evolve): # if save
373
+ ckpt = {
374
+ 'epoch': epoch,
375
+ 'best_fitness': best_fitness,
376
+ 'model': deepcopy(de_parallel(model)).half(),
377
+ 'ema': deepcopy(ema.ema).half(),
378
+ 'updates': ema.updates,
379
+ 'optimizer': optimizer.state_dict(),
380
+ 'opt': vars(opt),
381
+ 'git': GIT_INFO, # {remote, branch, commit} if a git repo
382
+ 'date': datetime.now().isoformat()}
383
+
384
+ # Save last, best and delete
385
+ torch.save(ckpt, last)
386
+ if best_fitness == fi:
387
+ torch.save(ckpt, best)
388
+ if opt.save_period > 0 and epoch % opt.save_period == 0:
389
+ torch.save(ckpt, w / f'epoch{epoch}.pt')
390
+ logger.log_model(w / f'epoch{epoch}.pt')
391
+ del ckpt
392
+ # callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
393
+
394
+ # EarlyStopping
395
+ if RANK != -1: # if DDP training
396
+ broadcast_list = [stop if RANK == 0 else None]
397
+ dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
398
+ if RANK != 0:
399
+ stop = broadcast_list[0]
400
+ if stop:
401
+ break # must break all DDP ranks
402
+
403
+ # end epoch ----------------------------------------------------------------------------------------------------
404
+ # end training -----------------------------------------------------------------------------------------------------
405
+ if RANK in {-1, 0}:
406
+ LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
407
+ for f in last, best:
408
+ if f.exists():
409
+ strip_optimizer(f) # strip optimizers
410
+ if f is best:
411
+ LOGGER.info(f'\nValidating {f}...')
412
+ results, _, _ = validate.run(
413
+ data_dict,
414
+ batch_size=batch_size // WORLD_SIZE * 2,
415
+ imgsz=imgsz,
416
+ model=attempt_load(f, device).half(),
417
+ iou_thres=0.65 if is_coco else 0.60, # best pycocotools at iou 0.65
418
+ single_cls=single_cls,
419
+ dataloader=val_loader,
420
+ save_dir=save_dir,
421
+ save_json=is_coco,
422
+ verbose=True,
423
+ plots=plots,
424
+ callbacks=callbacks,
425
+ compute_loss=compute_loss,
426
+ mask_downsample_ratio=mask_ratio,
427
+ overlap=overlap) # val best model with plots
428
+ if is_coco:
429
+ # callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
430
+ metrics_dict = dict(zip(KEYS, list(mloss) + list(results) + lr))
431
+ logger.log_metrics(metrics_dict, epoch)
432
+
433
+ # callbacks.run('on_train_end', last, best, epoch, results)
434
+ # on train end callback using genericLogger
435
+ logger.log_metrics(dict(zip(KEYS[4:16], results)), epochs)
436
+ if not opt.evolve:
437
+ logger.log_model(best, epoch)
438
+ if plots:
439
+ plot_results_with_masks(file=save_dir / 'results.csv') # save results.png
440
+ files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
441
+ files = [(save_dir / f) for f in files if (save_dir / f).exists()] # filter
442
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
443
+ logger.log_images(files, "Results", epoch + 1)
444
+ logger.log_images(sorted(save_dir.glob('val*.jpg')), "Validation", epoch + 1)
445
+ torch.cuda.empty_cache()
446
+ return results
447
+
448
+
449
+ def parse_opt(known=False):
450
+ parser = argparse.ArgumentParser()
451
+ parser.add_argument('--weights', type=str, default=ROOT / 'yolo-seg.pt', help='initial weights path')
452
+ parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
453
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-seg.yaml', help='dataset.yaml path')
454
+ parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
455
+ parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
456
+ parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
457
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
458
+ parser.add_argument('--rect', action='store_true', help='rectangular training')
459
+ parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
460
+ parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
461
+ parser.add_argument('--noval', action='store_true', help='only validate final epoch')
462
+ parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
463
+ parser.add_argument('--noplots', action='store_true', help='save no plot files')
464
+ parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
465
+ parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
466
+ parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
467
+ parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
468
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
469
+ parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
470
+ parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
471
+ parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW', 'LION'], default='SGD', help='optimizer')
472
+ parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
473
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
474
+ parser.add_argument('--project', default=ROOT / 'runs/train-seg', help='save to project/name')
475
+ parser.add_argument('--name', default='exp', help='save to project/name')
476
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
477
+ parser.add_argument('--quad', action='store_true', help='quad dataloader')
478
+ parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
479
+ parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
480
+ parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
481
+ parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
482
+ parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
483
+ parser.add_argument('--seed', type=int, default=0, help='Global training seed')
484
+ parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
485
+ parser.add_argument('--close-mosaic', type=int, default=0, help='Experimental')
486
+
487
+ # Instance Segmentation Args
488
+ parser.add_argument('--mask-ratio', type=int, default=4, help='Downsample ground-truth masks by this ratio to save memory')
489
+ parser.add_argument('--no-overlap', action='store_true', help='Disable overlapping masks (overlapping masks train faster at slightly lower mAP)')
490
+
491
+ return parser.parse_known_args()[0] if known else parser.parse_args()
492
+
493
+
494
+ def main(opt, callbacks=Callbacks()):
495
+ # Checks
496
+ if RANK in {-1, 0}:
497
+ print_args(vars(opt))
498
+ #check_git_status()
499
+ #check_requirements()
500
+
501
+ # Resume
502
+ if opt.resume and not opt.evolve: # resume from specified or most recent last.pt
503
+ last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
504
+ opt_yaml = last.parent.parent / 'opt.yaml' # train options yaml
505
+ opt_data = opt.data # original dataset
506
+ if opt_yaml.is_file():
507
+ with open(opt_yaml, errors='ignore') as f:
508
+ d = yaml.safe_load(f)
509
+ else:
510
+ d = torch.load(last, map_location='cpu')['opt']
511
+ opt = argparse.Namespace(**d) # replace
512
+ opt.cfg, opt.weights, opt.resume = '', str(last), True # reinstate
513
+ if is_url(opt_data):
514
+ opt.data = check_file(opt_data) # avoid HUB resume auth timeout
515
+ else:
516
+ opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
517
+ check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
518
+ assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
519
+ if opt.evolve:
520
+ if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve
521
+ opt.project = str(ROOT / 'runs/evolve')
522
+ opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
523
+ if opt.name == 'cfg':
524
+ opt.name = Path(opt.cfg).stem # use model.yaml as name
525
+ opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
526
+
527
+ # DDP mode
528
+ device = select_device(opt.device, batch_size=opt.batch_size)
529
+ if LOCAL_RANK != -1:
530
+ msg = 'is not compatible with YOLO Multi-GPU DDP training'
531
+ assert not opt.image_weights, f'--image-weights {msg}'
532
+ assert not opt.evolve, f'--evolve {msg}'
533
+ assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
534
+ assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
535
+ assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
536
+ torch.cuda.set_device(LOCAL_RANK)
537
+ device = torch.device('cuda', LOCAL_RANK)
538
+ dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
539
+
540
+ # Train
541
+ if not opt.evolve:
542
+ train(opt.hyp, opt, device, callbacks)
543
+
544
+ # Evolve hyperparameters (optional)
545
+ else:
546
+ # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
547
+ meta = {
548
+ 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
549
+ 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
550
+ 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
551
+ 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
552
+ 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
553
+ 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
554
+ 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
555
+ 'box': (1, 0.02, 0.2), # box loss gain
556
+ 'cls': (1, 0.2, 4.0), # cls loss gain
557
+ 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
558
+ 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
559
+ 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
560
+ 'iou_t': (0, 0.1, 0.7), # IoU training threshold
561
+ 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
562
+ 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
563
+ 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
564
+ 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
565
+ 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
566
+ 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
567
+ 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
568
+ 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
569
+ 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
570
+ 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
571
+ 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
572
+ 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
573
+ 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
574
+ 'mosaic': (1, 0.0, 1.0), # image mosaic (probability)
575
+ 'mixup': (1, 0.0, 1.0), # image mixup (probability)
576
+ 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
577
+
578
+ with open(opt.hyp, errors='ignore') as f:
579
+ hyp = yaml.safe_load(f) # load hyps dict
580
+ if 'anchors' not in hyp: # anchors commented in hyp.yaml
581
+ hyp['anchors'] = 3
582
+ if opt.noautoanchor:
583
+ del hyp['anchors'], meta['anchors']
584
+ opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
585
+ # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
586
+ evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
587
+ if opt.bucket:
588
+ os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}') # download evolve.csv if exists
589
+
590
+ for _ in range(opt.evolve): # generations to evolve
591
+ if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
592
+ # Select parent(s)
593
+ parent = 'single' # parent selection method: 'single' or 'weighted'
594
+ x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
595
+ n = min(5, len(x)) # number of previous results to consider
596
+ x = x[np.argsort(-fitness(x))][:n] # top n mutations
597
+ w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
598
+ if parent == 'single' or len(x) == 1:
599
+ # x = x[random.randint(0, n - 1)] # random selection
600
+ x = x[random.choices(range(n), weights=w)[0]] # weighted selection
601
+ elif parent == 'weighted':
602
+ x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
603
+
604
+ # Mutate
605
+ mp, s = 0.8, 0.2 # mutation probability, sigma
606
+ npr = np.random
607
+ npr.seed(int(time.time()))
608
+ g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
609
+ ng = len(meta)
610
+ v = np.ones(ng)
611
+ while all(v == 1): # mutate until a change occurs (prevent duplicates)
612
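+ # per-hyp gains g scale mutation strength; mutation factors stay within [0.3, 3.0] around 1.0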
+ v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
613
+ for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
614
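+ # the +7 offset skips the leading result columns written before the hyp values in evolve.csv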
+ hyp[k] = float(x[i + 7] * v[i]) # mutate
615
+
616
+ # Constrain to limits
617
+ for k, v in meta.items():
618
+ hyp[k] = max(hyp[k], v[1]) # lower limit
619
+ hyp[k] = min(hyp[k], v[2]) # upper limit
620
+ hyp[k] = round(hyp[k], 5) # significant digits
621
+
622
+ # Train mutation
623
+ results = train(hyp.copy(), opt, device, callbacks)
624
+ callbacks = Callbacks()
625
+ # Write mutation results
626
+ print_mutation(KEYS, results, hyp.copy(), save_dir, opt.bucket)
627
+
628
+ # Plot results
629
+ plot_evolve(evolve_csv)
630
+ LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
631
+ f"Results saved to {colorstr('bold', save_dir)}\n"
632
+ f'Usage example: $ python train.py --hyp {evolve_yaml}')
633
+
634
+
635
+ def run(**kwargs):
636
+ # Usage: import segment.train as train; train.run(data='coco128-seg.yaml', imgsz=320, weights='yolo-seg.pt')
637
+ opt = parse_opt(True)
638
+ for k, v in kwargs.items():
639
+ setattr(opt, k, v)
640
+ main(opt)
641
+ return opt
642
+
643
+
644
+ if __name__ == "__main__":
645
+ opt = parse_opt()
646
+ main(opt)
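As the `run(**kwargs)` helper above suggests, training can also be launched from Python rather than the CLI. A minimal sketch, assuming a local `coco128-seg.yaml` dataset file and a `yolo-seg.pt` checkpoint (hypothetical paths, for illustration only):

import segment.train as train

# quick smoke-test run: 3 epochs at reduced resolution and a small batch size
opt = train.run(data='coco128-seg.yaml', weights='yolo-seg.pt', imgsz=320, epochs=3, batch_size=8)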
yolov9/segment/val.py ADDED
@@ -0,0 +1,457 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ from multiprocessing.pool import ThreadPool
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ FILE = Path(__file__).resolve()
13
+ ROOT = FILE.parents[1] # YOLO root directory
14
+ if str(ROOT) not in sys.path:
15
+ sys.path.append(str(ROOT)) # add ROOT to PATH
16
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
17
+
18
+ import torch.nn.functional as F
19
+
20
+ from models.common import DetectMultiBackend
21
+ from models.yolo import SegmentationModel
22
+ from utils.callbacks import Callbacks
23
+ from utils.general import (LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, Profile, check_dataset, check_img_size,
24
+ check_requirements, check_yaml, coco80_to_coco91_class, colorstr, increment_path,
25
+ non_max_suppression, print_args, scale_boxes, xywh2xyxy, xyxy2xywh)
26
+ from utils.metrics import ConfusionMatrix, box_iou
27
+ from utils.plots import output_to_target, plot_val_study
28
+ from utils.segment.dataloaders import create_dataloader
29
+ from utils.segment.general import mask_iou, process_mask, process_mask_upsample, scale_image
30
+ from utils.segment.metrics import Metrics, ap_per_class_box_and_mask
31
+ from utils.segment.plots import plot_images_and_masks
32
+ from utils.torch_utils import de_parallel, select_device, smart_inference_mode
33
+
34
+
35
+ def save_one_txt(predn, save_conf, shape, file):
36
+ # Save one txt result
37
+ gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
38
+ for *xyxy, conf, cls in predn.tolist():
39
+ xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
40
+ line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
41
+ with open(file, 'a') as f:
42
+ f.write(('%g ' * len(line)).rstrip() % line + '\n')
43
+
44
+
45
+ def save_one_json(predn, jdict, path, class_map, pred_masks):
46
+ # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
47
+ from pycocotools.mask import encode
48
+
49
+ def single_encode(x):
50
+ rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
51
+ rle["counts"] = rle["counts"].decode("utf-8")
52
+ return rle
53
+
54
+ image_id = int(path.stem) if path.stem.isnumeric() else path.stem
55
+ box = xyxy2xywh(predn[:, :4]) # xywh
56
+ box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
57
+ pred_masks = np.transpose(pred_masks, (2, 0, 1))
58
+ with ThreadPool(NUM_THREADS) as pool:
59
+ rles = pool.map(single_encode, pred_masks)
60
+ for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
61
+ jdict.append({
62
+ 'image_id': image_id,
63
+ 'category_id': class_map[int(p[5])],
64
+ 'bbox': [round(x, 3) for x in b],
65
+ 'score': round(p[4], 5),
66
+ 'segmentation': rles[i]})
67
+
68
+
69
+ def process_batch(detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
70
+ """
71
+ Return correct prediction matrix
72
+ Arguments:
73
+ detections (array[N, 6]), x1, y1, x2, y2, conf, class
74
+ labels (array[M, 5]), class, x1, y1, x2, y2
75
+ Returns:
76
+ correct (array[N, 10]), for 10 IoU levels
77
+ """
78
+ if masks:
79
+ if overlap:
80
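+ # overlap encoding: one image stores all instances, pixel value i+1 marks instance i; expand back to per-instance binary masks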
+ nl = len(labels)
81
+ index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
82
+ gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
83
+ gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
84
+ if gt_masks.shape[1:] != pred_masks.shape[1:]:
85
+ gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
86
+ gt_masks = gt_masks.gt_(0.5)
87
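+ # flatten HxW masks to vectors so mask IoU reduces to pairwise intersection/union over pixels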
+ iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
88
+ else: # boxes
89
+ iou = box_iou(labels[:, 1:], detections[:, :4])
90
+
91
+ correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
92
+ correct_class = labels[:, 0:1] == detections[:, 5]
93
+ for i in range(len(iouv)):
94
+ x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match
95
+ if x[0].shape[0]:
96
+ matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou]
97
+ if x[0].shape[0] > 1:
98
+ matches = matches[matches[:, 2].argsort()[::-1]]
99
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
100
+ # matches = matches[matches[:, 2].argsort()[::-1]]
101
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
102
+ correct[matches[:, 1].astype(int), i] = True
103
+ return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
104
+
105
+
106
+ @smart_inference_mode()
107
+ def run(
108
+ data,
109
+ weights=None, # model.pt path(s)
110
+ batch_size=32, # batch size
111
+ imgsz=640, # inference size (pixels)
112
+ conf_thres=0.001, # confidence threshold
113
+ iou_thres=0.6, # NMS IoU threshold
114
+ max_det=300, # maximum detections per image
115
+ task='val', # train, val, test, speed or study
116
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
117
+ workers=8, # max dataloader workers (per RANK in DDP mode)
118
+ single_cls=False, # treat as single-class dataset
119
+ augment=False, # augmented inference
120
+ verbose=False, # verbose output
121
+ save_txt=False, # save results to *.txt
122
+ save_hybrid=False, # save label+prediction hybrid results to *.txt
123
+ save_conf=False, # save confidences in --save-txt labels
124
+ save_json=False, # save a COCO-JSON results file
125
+ project=ROOT / 'runs/val-seg', # save to project/name
126
+ name='exp', # save to project/name
127
+ exist_ok=False, # existing project/name ok, do not increment
128
+ half=True, # use FP16 half-precision inference
129
+ dnn=False, # use OpenCV DNN for ONNX inference
130
+ model=None,
131
+ dataloader=None,
132
+ save_dir=Path(''),
133
+ plots=True,
134
+ overlap=False,
135
+ mask_downsample_ratio=1,
136
+ compute_loss=None,
137
+ callbacks=Callbacks(),
138
+ ):
139
+ if save_json:
140
+ check_requirements(['pycocotools'])
141
+ process = process_mask_upsample # more accurate
142
+ else:
143
+ process = process_mask # faster
144
+
145
+ # Initialize/load model and set device
146
+ training = model is not None
147
+ if training: # called by train.py
148
+ device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model
149
+ half &= device.type != 'cpu' # half precision only supported on CUDA
150
+ model.half() if half else model.float()
151
+ nm = de_parallel(model).model[-1].nm # number of masks
152
+ else: # called directly
153
+ device = select_device(device, batch_size=batch_size)
154
+
155
+ # Directories
156
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
157
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
158
+
159
+ # Load model
160
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
161
+ stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
162
+ imgsz = check_img_size(imgsz, s=stride) # check image size
163
+ half = model.fp16 # FP16 supported on limited backends with CUDA
164
+ nm = de_parallel(model).model.model[-1].nm if isinstance(model, SegmentationModel) else 32 # number of masks
165
+ if engine:
166
+ batch_size = model.batch_size
167
+ else:
168
+ device = model.device
169
+ if not (pt or jit):
170
+ batch_size = 1 # export.py models default to batch-size 1
171
+ LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
172
+
173
+ # Data
174
+ data = check_dataset(data) # check
175
+
176
+ # Configure
177
+ model.eval()
178
+ cuda = device.type != 'cpu'
179
+ #is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'coco{os.sep}val2017.txt') # COCO dataset
180
+ is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'val2017.txt') # COCO dataset
181
+ nc = 1 if single_cls else int(data['nc']) # number of classes
182
+ iouv = torch.linspace(0.5, 0.95, 10, device=device) # iou vector for mAP@0.5:0.95
183
+ niou = iouv.numel()
184
+
185
+ # Dataloader
186
+ if not training:
187
+ if pt and not single_cls: # check --weights are trained on --data
188
+ ncm = model.model.nc
189
+ assert ncm == nc, f'{weights} ({ncm} classes) trained on different --data than what you passed ({nc} ' \
190
+ f'classes). Pass correct combination of --weights and --data that are trained together.'
191
+ model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz)) # warmup
192
+ pad, rect = (0.0, False) if task == 'speed' else (0.5, pt) # square inference for benchmarks
193
+ task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
194
+ dataloader = create_dataloader(data[task],
195
+ imgsz,
196
+ batch_size,
197
+ stride,
198
+ single_cls,
199
+ pad=pad,
200
+ rect=rect,
201
+ workers=workers,
202
+ prefix=colorstr(f'{task}: '),
203
+ overlap_mask=overlap,
204
+ mask_downsample_ratio=mask_downsample_ratio)[0]
205
+
206
+ seen = 0
207
+ confusion_matrix = ConfusionMatrix(nc=nc)
208
+ names = model.names if hasattr(model, 'names') else model.module.names # get class names
209
+ if isinstance(names, (list, tuple)): # old format
210
+ names = dict(enumerate(names))
211
+ class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
212
+ s = ('%22s' + '%11s' * 10) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", "R",
213
+ "mAP50", "mAP50-95)")
214
+ dt = Profile(), Profile(), Profile()
215
+ metrics = Metrics()
216
+ loss = torch.zeros(4, device=device)
217
+ jdict, stats = [], []
218
+ # callbacks.run('on_val_start')
219
+ pbar = tqdm(dataloader, desc=s, bar_format=TQDM_BAR_FORMAT) # progress bar
220
+ for batch_i, (im, targets, paths, shapes, masks) in enumerate(pbar):
221
+ # callbacks.run('on_val_batch_start')
222
+ with dt[0]:
223
+ if cuda:
224
+ im = im.to(device, non_blocking=True)
225
+ targets = targets.to(device)
226
+ masks = masks.to(device)
227
+ masks = masks.float()
228
+ im = im.half() if half else im.float() # uint8 to fp16/32
229
+ im /= 255 # 0 - 255 to 0.0 - 1.0
230
+ nb, _, height, width = im.shape # batch size, channels, height, width
231
+
232
+ # Inference
233
+ with dt[1]:
234
+ preds, train_out = model(im)
+ protos = train_out[-1] # mask prototypes are the last element of the training outputs
243
+
244
+ # Loss
245
+ if compute_loss:
246
+ loss += compute_loss(train_out, targets, masks)[1] # box, seg, cls, dfl
247
+
248
+ # NMS
249
+ targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
250
+ lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
251
+ with dt[2]:
252
+ preds = non_max_suppression(preds,
253
+ conf_thres,
254
+ iou_thres,
255
+ labels=lb,
256
+ multi_label=True,
257
+ agnostic=single_cls,
258
+ max_det=max_det,
259
+ nm=nm)
260
+
261
+ # Metrics
262
+ plot_masks = [] # masks for plotting
263
+ for si, (pred, proto) in enumerate(zip(preds, protos)):
264
+ labels = targets[targets[:, 0] == si, 1:]
265
+ nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
266
+ path, shape = Path(paths[si]), shapes[si][0]
267
+ correct_masks = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
268
+ correct_bboxes = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
269
+ seen += 1
270
+
271
+ if npr == 0:
272
+ if nl:
273
+ stats.append((correct_masks, correct_bboxes, *torch.zeros((2, 0), device=device), labels[:, 0]))
274
+ if plots:
275
+ confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
276
+ continue
277
+
278
+ # Masks
279
+ midx = [si] if overlap else targets[:, 0] == si
280
+             gt_masks = masks[midx]
+             pred_masks = process(proto, pred[:, 6:], pred[:, :4], shape=im[si].shape[1:])
+
+             # Predictions
+             if single_cls:
+                 pred[:, 5] = 0
+             predn = pred.clone()
+             scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
+
+             # Evaluate
+             if nl:
+                 tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
+                 scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
+                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
+                 correct_bboxes = process_batch(predn, labelsn, iouv)
+                 correct_masks = process_batch(predn, labelsn, iouv, pred_masks, gt_masks, overlap=overlap, masks=True)
+                 if plots:
+                     confusion_matrix.process_batch(predn, labelsn)
+             stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0]))  # (correct_masks, correct_bboxes, conf, pcls, tcls)
+
+             pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
+             if plots and batch_i < 3:
+                 plot_masks.append(pred_masks[:15].cpu())  # filter top 15 to plot
+
+             # Save/log
+             if save_txt:
+                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
+             if save_json:
+                 pred_masks = scale_image(im[si].shape[1:],
+                                          pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1])
+                 save_one_json(predn, jdict, path, class_map, pred_masks)  # append to COCO-JSON dictionary
+             # callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
+
+         # Plot images
+         if plots and batch_i < 3:
+             if len(plot_masks):
+                 plot_masks = torch.cat(plot_masks, dim=0)
+             plot_images_and_masks(im, targets, masks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
+             plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, paths,
+                                   save_dir / f'val_batch{batch_i}_pred.jpg', names)  # pred
+
+         # callbacks.run('on_val_batch_end')
+
+     # Compute metrics
+     stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)]  # to numpy
+     if len(stats) and stats[0].any():
+         results = ap_per_class_box_and_mask(*stats, plot=plots, save_dir=save_dir, names=names)
+         metrics.update(results)
+     nt = np.bincount(stats[4].astype(int), minlength=nc)  # number of targets per class
+
+     # Print results
+     pf = '%22s' + '%11i' * 2 + '%11.3g' * 8  # print format
+     LOGGER.info(pf % ("all", seen, nt.sum(), *metrics.mean_results()))
+     if nt.sum() == 0:
+         LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, cannot compute metrics without labels')
+
+     # Print results per class
+     if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
+         for i, c in enumerate(metrics.ap_class_index):
+             LOGGER.info(pf % (names[c], seen, nt[c], *metrics.class_result(i)))
+
+     # Print speeds
+     t = tuple(x.t / seen * 1E3 for x in dt)  # speeds per image
+     if not training:
+         shape = (batch_size, 3, imgsz, imgsz)
+         LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
+
+     # Plots
+     if plots:
+         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
+     # callbacks.run('on_val_end')
+
+     mp_bbox, mr_bbox, map50_bbox, map_bbox, mp_mask, mr_mask, map50_mask, map_mask = metrics.mean_results()
+
+     # Save JSON
+     if save_json and len(jdict):
+         w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else ''  # weights
+         anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json')  # annotations json
+         pred_json = str(save_dir / f"{w}_predictions.json")  # predictions json
+         LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
+         with open(pred_json, 'w') as f:
+             json.dump(jdict, f)
+
+         try:  # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
+             from pycocotools.coco import COCO
+             from pycocotools.cocoeval import COCOeval
+
+             anno = COCO(anno_json)  # init annotations api
+             pred = anno.loadRes(pred_json)  # init predictions api
+             results = []
+             for eval in COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm'):
+                 if is_coco:
+                     eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.im_files]  # image IDs to evaluate
+                 eval.evaluate()
+                 eval.accumulate()
+                 eval.summarize()
+                 results.extend(eval.stats[:2])  # update results (mAP@0.5:0.95, mAP@0.5)
+             map_bbox, map50_bbox, map_mask, map50_mask = results
+         except Exception as e:
+             LOGGER.info(f'pycocotools unable to run: {e}')
+
+     # Return results
+     model.float()  # for training
+     if not training:
+         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
+         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
+     final_metric = mp_bbox, mr_bbox, map50_bbox, map_bbox, mp_mask, mr_mask, map50_mask, map_mask
+     return (*final_metric, *(loss.cpu() / len(dataloader)).tolist()), metrics.get_maps(nc), t
+
+
+ def parse_opt():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-seg.yaml', help='dataset.yaml path')
+     parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo-seg.pt', help='model path(s)')
+     parser.add_argument('--batch-size', type=int, default=32, help='batch size')
+     parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
+     parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
+     parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold')
+     parser.add_argument('--max-det', type=int, default=300, help='maximum detections per image')
+     parser.add_argument('--task', default='val', help='train, val, test, speed or study')
+     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
+     parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
+     parser.add_argument('--augment', action='store_true', help='augmented inference')
+     parser.add_argument('--verbose', action='store_true', help='report mAP by class')
+     parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
+     parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
+     parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
+     parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file')
+     parser.add_argument('--project', default=ROOT / 'runs/val-seg', help='save results to project/name')
+     parser.add_argument('--name', default='exp', help='save to project/name')
+     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+     parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+     opt = parser.parse_args()
+     opt.data = check_yaml(opt.data)  # check YAML
+     # opt.save_json |= opt.data.endswith('coco.yaml')
+     opt.save_txt |= opt.save_hybrid
+     print_args(vars(opt))
+     return opt
+
+
+ def main(opt):
+     # check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
+
+     if opt.task in ('train', 'val', 'test'):  # run normally
+         if opt.conf_thres > 0.001:  # https://github.com/ultralytics/yolov5/issues/1466
+             LOGGER.warning(f'WARNING ⚠️ confidence threshold {opt.conf_thres} > 0.001 produces invalid results')
+         if opt.save_hybrid:
+             LOGGER.warning('WARNING ⚠️ --save-hybrid returns high mAP from hybrid labels, not from predictions alone')
+         run(**vars(opt))
+
+     else:
+         weights = opt.weights if isinstance(opt.weights, list) else [opt.weights]
+         opt.half = torch.cuda.is_available() and opt.device != 'cpu'  # FP16 for fastest results
+         if opt.task == 'speed':  # speed benchmarks
+             # python val.py --task speed --data coco.yaml --batch 1 --weights yolo.pt...
+             opt.conf_thres, opt.iou_thres, opt.save_json = 0.25, 0.45, False
+             for opt.weights in weights:
+                 run(**vars(opt), plots=False)
+
+         elif opt.task == 'study':  # speed vs mAP benchmarks
+             # python val.py --task study --data coco.yaml --iou 0.7 --weights yolo.pt...
+             for opt.weights in weights:
+                 f = f'study_{Path(opt.data).stem}_{Path(opt.weights).stem}.txt'  # filename to save to
+                 x, y = list(range(256, 1536 + 128, 128)), []  # x axis (image sizes), y axis
+                 for opt.imgsz in x:  # img-size
+                     LOGGER.info(f'\nRunning {f} --imgsz {opt.imgsz}...')
+                     r, _, t = run(**vars(opt), plots=False)
+                     y.append(r + t)  # results and times
+                 np.savetxt(f, y, fmt='%10.4g')  # save
+             os.system('zip -r study.zip study_*.txt')
+             plot_val_study(x=x)  # plot
+
+
+ if __name__ == "__main__":
+     opt = parse_opt()
+     main(opt)
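
Note: this validator's run() function accepts the same keywords that parse_opt() defines above, so it can also be driven from Python instead of the CLI. A minimal sketch, not part of this commit, assuming the yolov9/segment directory is the working directory and using placeholder paths:

import val as validate  # yolov9/segment/val.py

# run() returns (metrics tuple, per-class mAPs, speed tuple), matching the return statement above
results, maps, times = validate.run(
    data='data/coco128-seg.yaml',  # placeholder dataset YAML
    weights='yolo-seg.pt',         # placeholder segmentation checkpoint
    imgsz=640,
    conf_thres=0.001,              # keep <= 0.001 for valid mAP, per the warning in main()
    iou_thres=0.6,
)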
yolov9/train.py ADDED
@@ -0,0 +1,634 @@
+ import argparse
+ import math
+ import os
+ import random
+ import sys
+ import time
+ from copy import deepcopy
+ from datetime import datetime
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import yaml
+ from torch.optim import lr_scheduler
+ from tqdm import tqdm
+
+ FILE = Path(__file__).resolve()
+ ROOT = FILE.parents[0]  # root directory
+ if str(ROOT) not in sys.path:
+     sys.path.append(str(ROOT))  # add ROOT to PATH
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+ import val as validate  # for end-of-epoch mAP
+ from models.experimental import attempt_load
+ from models.yolo import Model
+ from utils.autoanchor import check_anchors
+ from utils.autobatch import check_train_batch_size
+ from utils.callbacks import Callbacks
+ from utils.dataloaders import create_dataloader
+ from utils.downloads import attempt_download, is_url
+ from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_img_size,
+                            check_suffix, check_yaml, colorstr, get_latest_run, increment_path, init_seeds,
+                            intersect_dicts, labels_to_class_weights, labels_to_image_weights, methods,
+                            one_cycle, one_flat_cycle, print_args, print_mutation, strip_optimizer, yaml_save)
+ from utils.loggers import Loggers
+ from utils.loggers.comet.comet_utils import check_comet_resume
+ from utils.loss_tal import ComputeLoss
+ from utils.metrics import fitness
+ from utils.plots import plot_evolve
+ from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP,
+                                smart_optimizer, smart_resume, torch_distributed_zero_first)
+
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+ RANK = int(os.getenv('RANK', -1))
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
+ GIT_INFO = None
+
+
+ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
+     save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
+         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
+         opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
+     callbacks.run('on_pretrain_routine_start')
+
+     # Directories
+     w = save_dir / 'weights'  # weights dir
+     (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
+     last, best = w / 'last.pt', w / 'best.pt'
+     last_striped, best_striped = w / 'last_striped.pt', w / 'best_striped.pt'
+
+     # Hyperparameters
+     if isinstance(hyp, str):
+         with open(hyp, errors='ignore') as f:
+             hyp = yaml.safe_load(f)  # load hyps dict
+     LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
+     hyp['anchor_t'] = 5.0
+     opt.hyp = hyp.copy()  # for saving hyps to checkpoints
+
+     # Save run settings
+     if not evolve:
+         yaml_save(save_dir / 'hyp.yaml', hyp)
+         yaml_save(save_dir / 'opt.yaml', vars(opt))
+
+     # Loggers
+     data_dict = None
+     if RANK in {-1, 0}:
+         loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)  # loggers instance
+
+         # Register actions
+         for k in methods(loggers):
+             callbacks.register_action(k, callback=getattr(loggers, k))
+
+         # Process custom dataset artifact link
+         data_dict = loggers.remote_dataset
+         if resume:  # If resuming runs from remote artifact
+             weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
+
+     # Config
+     plots = not evolve and not opt.noplots  # create plots
+     cuda = device.type != 'cpu'
+     init_seeds(opt.seed + 1 + RANK, deterministic=True)
+     with torch_distributed_zero_first(LOCAL_RANK):
+         data_dict = data_dict or check_dataset(data)  # check if None
+     train_path, val_path = data_dict['train'], data_dict['val']
+     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
+     names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
+     # is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')  # COCO dataset
+     is_coco = isinstance(val_path, str) and val_path.endswith('val2017.txt')  # COCO dataset
+
+     # Model
+     check_suffix(weights, '.pt')  # check weights
+     pretrained = weights.endswith('.pt')
+     if pretrained:
+         with torch_distributed_zero_first(LOCAL_RANK):
+             weights = attempt_download(weights)  # download if not found locally
+         ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
+         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
+         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+         csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
+         model.load_state_dict(csd, strict=False)  # load
+         LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
+     else:
+         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+     amp = check_amp(model)  # check AMP
+
+     # Freeze
+     freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
+     for k, v in model.named_parameters():
+         # v.requires_grad = True  # train all layers TODO: uncomment this line as in master
+         # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
+         if any(x in k for x in freeze):
+             LOGGER.info(f'freezing {k}')
+             v.requires_grad = False
+
+     # Image size
+     gs = max(int(model.stride.max()), 32)  # grid size (max stride)
+     imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple
+
+     # Batch size
+     if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
+         batch_size = check_train_batch_size(model, imgsz, amp)
+         loggers.on_params_update({"batch_size": batch_size})
+
+     # Optimizer
+     nbs = 64  # nominal batch size
+     accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
+     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
+     optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
+
+     # Scheduler
+     if opt.cos_lr:
+         lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
+     elif opt.flat_cos_lr:
+         lf = one_flat_cycle(1, hyp['lrf'], epochs)  # flat cosine 1->hyp['lrf']
+     elif opt.fixed_lr:
+         lf = lambda x: 1.0
+     else:
+         lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
+
+     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+     # from utils.plots import plot_lr_scheduler; plot_lr_scheduler(optimizer, scheduler, epochs)
+
+     # EMA
+     ema = ModelEMA(model) if RANK in {-1, 0} else None
+
+     # Resume
+     best_fitness, start_epoch = 0.0, 0
+     if pretrained:
+         if resume:
+             best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
+         del ckpt, csd
+
+     # DP mode
+     if cuda and RANK == -1 and torch.cuda.device_count() > 1:
+         LOGGER.warning('WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.')
+         model = torch.nn.DataParallel(model)
+
+     # SyncBatchNorm
+     if opt.sync_bn and cuda and RANK != -1:
+         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
+         LOGGER.info('Using SyncBatchNorm()')
+
+     # Trainloader
+     train_loader, dataset = create_dataloader(train_path,
+                                               imgsz,
+                                               batch_size // WORLD_SIZE,
+                                               gs,
+                                               single_cls,
+                                               hyp=hyp,
+                                               augment=True,
+                                               cache=None if opt.cache == 'val' else opt.cache,
+                                               rect=opt.rect,
+                                               rank=LOCAL_RANK,
+                                               workers=workers,
+                                               image_weights=opt.image_weights,
+                                               close_mosaic=opt.close_mosaic != 0,
+                                               quad=opt.quad,
+                                               prefix=colorstr('train: '),
+                                               shuffle=True,
+                                               min_items=opt.min_items)
+     labels = np.concatenate(dataset.labels, 0)
+     mlc = int(labels[:, 0].max())  # max label class
+     assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
+
+     # Process 0
+     if RANK in {-1, 0}:
+         val_loader = create_dataloader(val_path,
+                                        imgsz,
+                                        batch_size // WORLD_SIZE * 2,
+                                        gs,
+                                        single_cls,
+                                        hyp=hyp,
+                                        cache=None if noval else opt.cache,
+                                        rect=True,
+                                        rank=-1,
+                                        workers=workers * 2,
+                                        pad=0.5,
+                                        prefix=colorstr('val: '))[0]
+
+         if not resume:
+             # if not opt.noautoanchor:
+             #     check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)  # run AutoAnchor
+             model.half().float()  # pre-reduce anchor precision
+
+         callbacks.run('on_pretrain_routine_end', labels, names)
+
+     # DDP mode
+     if cuda and RANK != -1:
+         model = smart_DDP(model)
+
+     # Model attributes
+     nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
+     # hyp['box'] *= 3 / nl  # scale to layers
+     # hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
+     # hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
+     hyp['label_smoothing'] = opt.label_smoothing
+     model.nc = nc  # attach number of classes to model
+     model.hyp = hyp  # attach hyperparameters to model
+     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
+     model.names = names
+
+     # Start training
+     t0 = time.time()
+     nb = len(train_loader)  # number of batches
+     nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
+     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
+     last_opt_step = -1
+     maps = np.zeros(nc)  # mAP per class
+     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
+     scheduler.last_epoch = start_epoch - 1  # do not move
+     scaler = torch.cuda.amp.GradScaler(enabled=amp)
+     stopper, stop = EarlyStopping(patience=opt.patience), False
+     compute_loss = ComputeLoss(model)  # init loss class
+     callbacks.run('on_train_start')
+     LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
+                 f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
+                 f"Logging results to {colorstr('bold', save_dir)}\n"
+                 f'Starting training for {epochs} epochs...')
+     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
+         callbacks.run('on_train_epoch_start')
+         model.train()
+
+         # Update image weights (optional, single-GPU only)
+         if opt.image_weights:
+             cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
+             iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
+             dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
+         if epoch == (epochs - opt.close_mosaic):
+             LOGGER.info("Closing dataloader mosaic")
+             dataset.mosaic = False
+
+         # Update mosaic border (optional)
+         # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
+         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
+
+         mloss = torch.zeros(3, device=device)  # mean losses
+         if RANK != -1:
+             train_loader.sampler.set_epoch(epoch)
+         pbar = enumerate(train_loader)
+         LOGGER.info(('\n' + '%11s' * 7) % ('Epoch', 'GPU_mem', 'box_loss', 'cls_loss', 'dfl_loss', 'Instances', 'Size'))
+         if RANK in {-1, 0}:
+             pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
+         optimizer.zero_grad()
+         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
+             callbacks.run('on_train_batch_start')
+             ni = i + nb * epoch  # number integrated batches (since train start)
+             imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0
+
+             # Warmup
+             if ni <= nw:
+                 xi = [0, nw]  # x interp
+                 # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
+                 accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
+                 for j, x in enumerate(optimizer.param_groups):
+                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                     x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
+                     if 'momentum' in x:
+                         x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
+
+             # Multi-scale
+             if opt.multi_scale:
+                 sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
+                 sf = sz / max(imgs.shape[2:])  # scale factor
+                 if sf != 1:
+                     ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
+                     imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
+
+             # Forward
+             with torch.cuda.amp.autocast(amp):
+                 pred = model(imgs)  # forward
+                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
+                 if RANK != -1:
+                     loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
+                 if opt.quad:
+                     loss *= 4.
+
+             # Backward
+             scaler.scale(loss).backward()
+
+             # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+             if ni - last_opt_step >= accumulate:
+                 scaler.unscale_(optimizer)  # unscale gradients
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
+                 scaler.step(optimizer)  # optimizer.step
+                 scaler.update()
+                 optimizer.zero_grad()
+                 if ema:
+                     ema.update(model)
+                 last_opt_step = ni
+
+             # Log
+             if RANK in {-1, 0}:
+                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
+                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
+                 pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
+                                      (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
+                 callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths, list(mloss))
+                 if callbacks.stop_training:
+                     return
+             # end batch ------------------------------------------------------------------------------------------------
+
+         # Scheduler
+         lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
+         scheduler.step()
+
+         if RANK in {-1, 0}:
+             # mAP
+             callbacks.run('on_train_epoch_end', epoch=epoch)
+             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
+             final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
+             if not noval or final_epoch:  # Calculate mAP
+                 results, maps, _ = validate.run(data_dict,
+                                                 batch_size=batch_size // WORLD_SIZE * 2,
+                                                 imgsz=imgsz,
+                                                 half=amp,
+                                                 model=ema.ema,
+                                                 single_cls=single_cls,
+                                                 dataloader=val_loader,
+                                                 save_dir=save_dir,
+                                                 plots=False,
+                                                 callbacks=callbacks,
+                                                 compute_loss=compute_loss)
+
+             # Update best mAP
+             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
+             stop = stopper(epoch=epoch, fitness=fi)  # early stop check
+             if fi > best_fitness:
+                 best_fitness = fi
+             log_vals = list(mloss) + list(results) + lr
+             callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
+
+             # Save model
+             if (not nosave) or (final_epoch and not evolve):  # if save
+                 ckpt = {
+                     'epoch': epoch,
+                     'best_fitness': best_fitness,
+                     'model': deepcopy(de_parallel(model)).half(),
+                     'ema': deepcopy(ema.ema).half(),
+                     'updates': ema.updates,
+                     'optimizer': optimizer.state_dict(),
+                     'opt': vars(opt),
+                     'git': GIT_INFO,  # {remote, branch, commit} if a git repo
+                     'date': datetime.now().isoformat()}
+
+                 # Save last, best and delete
+                 torch.save(ckpt, last)
+                 if best_fitness == fi:
+                     torch.save(ckpt, best)
+                 if opt.save_period > 0 and epoch % opt.save_period == 0:
+                     torch.save(ckpt, w / f'epoch{epoch}.pt')
+                 del ckpt
+                 callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
+
+         # EarlyStopping
+         if RANK != -1:  # if DDP training
+             broadcast_list = [stop if RANK == 0 else None]
+             dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+             if RANK != 0:
+                 stop = broadcast_list[0]
+         if stop:
+             break  # must break all DDP ranks
+
+         # end epoch ----------------------------------------------------------------------------------------------------
+     # end training -----------------------------------------------------------------------------------------------------
+     if RANK in {-1, 0}:
+         LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
+         for f in last, best:
+             if f.exists():
+                 if f is last:
+                     strip_optimizer(f, last_striped)  # strip optimizers
+                 else:
+                     strip_optimizer(f, best_striped)  # strip optimizers
+                 if f is best:
+                     LOGGER.info(f'\nValidating {f}...')
+                     results, _, _ = validate.run(
+                         data_dict,
+                         batch_size=batch_size // WORLD_SIZE * 2,
+                         imgsz=imgsz,
+                         model=attempt_load(f, device).half(),
+                         single_cls=single_cls,
+                         dataloader=val_loader,
+                         save_dir=save_dir,
+                         save_json=is_coco,
+                         verbose=True,
+                         plots=plots,
+                         callbacks=callbacks,
+                         compute_loss=compute_loss)  # val best model with plots
+                     if is_coco:
+                         callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
+
+         callbacks.run('on_train_end', last, best, epoch, results)
+
+     torch.cuda.empty_cache()
+     return results
+
+
+ def parse_opt(known=False):
+     parser = argparse.ArgumentParser()
+     # parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='initial weights path')
+     # parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
+     parser.add_argument('--weights', type=str, default='', help='initial weights path')
+     parser.add_argument('--cfg', type=str, default='yolo.yaml', help='model.yaml path')
+     parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
+     parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
+     parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
+     parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
+     parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
+     parser.add_argument('--rect', action='store_true', help='rectangular training')
+     parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
+     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
+     parser.add_argument('--noval', action='store_true', help='only validate final epoch')
+     parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
+     parser.add_argument('--noplots', action='store_true', help='save no plot files')
+     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
+     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
+     parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
+     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
+     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
+     parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
+     parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW', 'LION'], default='SGD', help='optimizer')
+     parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
+     parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
+     parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
+     parser.add_argument('--name', default='exp', help='save to project/name')
+     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
+     parser.add_argument('--quad', action='store_true', help='quad dataloader')
+     parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
+     parser.add_argument('--flat-cos-lr', action='store_true', help='flat cosine LR scheduler')
+     parser.add_argument('--fixed-lr', action='store_true', help='fixed LR scheduler')
+     parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
+     parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
+     parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
+     parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
+     parser.add_argument('--seed', type=int, default=0, help='Global training seed')
+     parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
+     parser.add_argument('--min-items', type=int, default=0, help='Experimental')
+     parser.add_argument('--close-mosaic', type=int, default=0, help='Experimental')
+
+     # Logger arguments
+     parser.add_argument('--entity', default=None, help='Entity')
+     parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='Upload data, "val" option')
+     parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval')
+     parser.add_argument('--artifact_alias', type=str, default='latest', help='Version of dataset artifact to use')
+
+     return parser.parse_known_args()[0] if known else parser.parse_args()
+
+
+ def main(opt, callbacks=Callbacks()):
+     # Checks
+     if RANK in {-1, 0}:
+         print_args(vars(opt))
+
+     # Resume (from specified or most recent last.pt)
+     if opt.resume and not check_comet_resume(opt) and not opt.evolve:
+         last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
+         opt_yaml = last.parent.parent / 'opt.yaml'  # train options yaml
+         opt_data = opt.data  # original dataset
+         if opt_yaml.is_file():
+             with open(opt_yaml, errors='ignore') as f:
+                 d = yaml.safe_load(f)
+         else:
+             d = torch.load(last, map_location='cpu')['opt']
+         opt = argparse.Namespace(**d)  # replace
+         opt.cfg, opt.weights, opt.resume = '', str(last), True  # reinstate
+         if is_url(opt_data):
+             opt.data = check_file(opt_data)  # avoid HUB resume auth timeout
+     else:
+         opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
+             check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
+         assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
+         if opt.evolve:
+             if opt.project == str(ROOT / 'runs/train'):  # if default project name, rename to runs/evolve
+                 opt.project = str(ROOT / 'runs/evolve')
+             opt.exist_ok, opt.resume = opt.resume, False  # pass resume to exist_ok and disable resume
+         if opt.name == 'cfg':
+             opt.name = Path(opt.cfg).stem  # use model.yaml as name
+         opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
+
+     # DDP mode
+     device = select_device(opt.device, batch_size=opt.batch_size)
+     if LOCAL_RANK != -1:
+         msg = 'is not compatible with YOLO Multi-GPU DDP training'
+         assert not opt.image_weights, f'--image-weights {msg}'
+         assert not opt.evolve, f'--evolve {msg}'
+         assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
+         assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
+         assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
+         torch.cuda.set_device(LOCAL_RANK)
+         device = torch.device('cuda', LOCAL_RANK)
+         dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
+
+     # Train
+     if not opt.evolve:
+         train(opt.hyp, opt, device, callbacks)
+
+     # Evolve hyperparameters (optional)
+     else:
+         # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
+         meta = {
+             'lr0': (1, 1e-5, 1e-1),  # initial learning rate (SGD=1E-2, Adam=1E-3)
+             'lrf': (1, 0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
+             'momentum': (0.3, 0.6, 0.98),  # SGD momentum/Adam beta1
+             'weight_decay': (1, 0.0, 0.001),  # optimizer weight decay
+             'warmup_epochs': (1, 0.0, 5.0),  # warmup epochs (fractions ok)
+             'warmup_momentum': (1, 0.0, 0.95),  # warmup initial momentum
+             'warmup_bias_lr': (1, 0.0, 0.2),  # warmup initial bias lr
+             'box': (1, 0.02, 0.2),  # box loss gain
+             'cls': (1, 0.2, 4.0),  # cls loss gain
+             'cls_pw': (1, 0.5, 2.0),  # cls BCELoss positive_weight
+             'obj': (1, 0.2, 4.0),  # obj loss gain (scale with pixels)
+             'obj_pw': (1, 0.5, 2.0),  # obj BCELoss positive_weight
+             'iou_t': (0, 0.1, 0.7),  # IoU training threshold
+             'anchor_t': (1, 2.0, 8.0),  # anchor-multiple threshold
+             'anchors': (2, 2.0, 10.0),  # anchors per output grid (0 to ignore)
+             'fl_gamma': (0, 0.0, 2.0),  # focal loss gamma (EfficientDet default gamma=1.5)
+             'hsv_h': (1, 0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+             'hsv_s': (1, 0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
+             'hsv_v': (1, 0.0, 0.9),  # image HSV-Value augmentation (fraction)
+             'degrees': (1, 0.0, 45.0),  # image rotation (+/- deg)
+             'translate': (1, 0.0, 0.9),  # image translation (+/- fraction)
+             'scale': (1, 0.0, 0.9),  # image scale (+/- gain)
+             'shear': (1, 0.0, 10.0),  # image shear (+/- deg)
+             'perspective': (0, 0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+             'flipud': (1, 0.0, 1.0),  # image flip up-down (probability)
+             'fliplr': (0, 0.0, 1.0),  # image flip left-right (probability)
+             'mosaic': (1, 0.0, 1.0),  # image mosaic (probability)
+             'mixup': (1, 0.0, 1.0),  # image mixup (probability)
+             'copy_paste': (1, 0.0, 1.0)}  # segment copy-paste (probability)
+
+         with open(opt.hyp, errors='ignore') as f:
+             hyp = yaml.safe_load(f)  # load hyps dict
+             if 'anchors' not in hyp:  # anchors commented in hyp.yaml
+                 hyp['anchors'] = 3
+         if opt.noautoanchor:
+             del hyp['anchors'], meta['anchors']
+         opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir)  # only val/save final epoch
+         # ei = [isinstance(x, (int, float)) for x in hyp.values()]  # evolvable indices
+         evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
+         if opt.bucket:
+             os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}')  # download evolve.csv if exists
+
+         for _ in range(opt.evolve):  # generations to evolve
+             if evolve_csv.exists():  # if evolve.csv exists: select best hyps and mutate
+                 # Select parent(s)
+                 parent = 'single'  # parent selection method: 'single' or 'weighted'
+                 x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
+                 n = min(5, len(x))  # number of previous results to consider
+                 x = x[np.argsort(-fitness(x))][:n]  # top n mutations
+                 w = fitness(x) - fitness(x).min() + 1E-6  # weights (sum > 0)
+                 if parent == 'single' or len(x) == 1:
+                     # x = x[random.randint(0, n - 1)]  # random selection
+                     x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
+                 elif parent == 'weighted':
+                     x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination
+
+                 # Mutate
+                 mp, s = 0.8, 0.2  # mutation probability, sigma
+                 npr = np.random
+                 npr.seed(int(time.time()))
+                 g = np.array([meta[k][0] for k in hyp.keys()])  # gains 0-1
+                 ng = len(meta)
+                 v = np.ones(ng)
+                 while all(v == 1):  # mutate until a change occurs (prevent duplicates)
+                     v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
+                 for i, k in enumerate(hyp.keys()):  # plt.hist(v.ravel(), 300)
+                     hyp[k] = float(x[i + 7] * v[i])  # mutate
+
+             # Constrain to limits
+             for k, v in meta.items():
+                 hyp[k] = max(hyp[k], v[1])  # lower limit
+                 hyp[k] = min(hyp[k], v[2])  # upper limit
+                 hyp[k] = round(hyp[k], 5)  # significant digits
+
+             # Train mutation
+             results = train(hyp.copy(), opt, device, callbacks)
+             callbacks = Callbacks()
+             # Write mutation results
+             keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
+                     'val/obj_loss', 'val/cls_loss')
+             print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
+
+         # Plot results
+         plot_evolve(evolve_csv)
+         LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
+                     f"Results saved to {colorstr('bold', save_dir)}\n"
+                     f'Usage example: $ python train.py --hyp {evolve_yaml}')
+
+
+ def run(**kwargs):
+     # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolo.pt')
+     opt = parse_opt(True)
+     for k, v in kwargs.items():
+         setattr(opt, k, v)
+     main(opt)
+     return opt
+
+
+ if __name__ == "__main__":
+     opt = parse_opt()
+     main(opt)
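
The same pattern applies to training: train.py defines run(**kwargs) (see the Usage comment above), which copies keyword arguments onto the namespace from parse_opt(known=True) and then calls main(opt). A minimal sketch with placeholder paths, not part of this commit:

import train

# Placeholder dataset YAML and model config; an empty weights string trains
# from scratch, matching the parse_opt() defaults above.
opt = train.run(
    data='coco128.yaml',
    cfg='models/detect/gelan-c.yaml',
    weights='',
    imgsz=640,
    batch_size=16,
    epochs=10,
)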
yolov9/train_dual.py ADDED
@@ -0,0 +1,644 @@
+ import argparse
+ import math
+ import os
+ import random
+ import sys
+ import time
+ from copy import deepcopy
+ from datetime import datetime
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+ import yaml
+ from torch.optim import lr_scheduler
+ from tqdm import tqdm
+
+ FILE = Path(__file__).resolve()
+ ROOT = FILE.parents[0]  # YOLO root directory
+ if str(ROOT) not in sys.path:
+     sys.path.append(str(ROOT))  # add ROOT to PATH
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
+
+ import val_dual as validate  # for end-of-epoch mAP
+ from models.experimental import attempt_load
+ from models.yolo import Model
+ from utils.autoanchor import check_anchors
+ from utils.autobatch import check_train_batch_size
+ from utils.callbacks import Callbacks
+ from utils.dataloaders import create_dataloader
+ from utils.downloads import attempt_download, is_url
+ from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info,
+                            check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr,
+                            get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
+                            labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer,
+                            yaml_save, one_flat_cycle)
+ from utils.loggers import Loggers
+ from utils.loggers.comet.comet_utils import check_comet_resume
+ from utils.loss_tal_dual import ComputeLoss
+ # from utils.loss_tal_dual import ComputeLossLH as ComputeLoss
+ # from utils.loss_tal_dual import ComputeLossLHCF as ComputeLoss
+ from utils.metrics import fitness
+ from utils.plots import plot_evolve
+ from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
+                                smart_resume, torch_distributed_zero_first)
+
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
+ RANK = int(os.getenv('RANK', -1))
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
+ GIT_INFO = None  # check_git_info()
+
+
+ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
+     save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
+         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
+         opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
+     callbacks.run('on_pretrain_routine_start')
+
+     # Directories
+     w = save_dir / 'weights'  # weights dir
+     (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
+     last, best = w / 'last.pt', w / 'best.pt'
+
+     # Hyperparameters
+     if isinstance(hyp, str):
+         with open(hyp, errors='ignore') as f:
+             hyp = yaml.safe_load(f)  # load hyps dict
+     LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
+     hyp['anchor_t'] = 5.0
+     opt.hyp = hyp.copy()  # for saving hyps to checkpoints
+
+     # Save run settings
+     if not evolve:
+         yaml_save(save_dir / 'hyp.yaml', hyp)
+         yaml_save(save_dir / 'opt.yaml', vars(opt))
+
+     # Loggers
+     data_dict = None
+     if RANK in {-1, 0}:
+         loggers = Loggers(save_dir, weights, opt, hyp, LOGGER)  # loggers instance
+
+         # Register actions
+         for k in methods(loggers):
+             callbacks.register_action(k, callback=getattr(loggers, k))
+
+         # Process custom dataset artifact link
+         data_dict = loggers.remote_dataset
+         if resume:  # If resuming runs from remote artifact
+             weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
+
+     # Config
+     plots = not evolve and not opt.noplots  # create plots
+     cuda = device.type != 'cpu'
+     init_seeds(opt.seed + 1 + RANK, deterministic=True)
+     with torch_distributed_zero_first(LOCAL_RANK):
+         data_dict = data_dict or check_dataset(data)  # check if None
+     train_path, val_path = data_dict['train'], data_dict['val']
+     nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
+     names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
+     # is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')  # COCO dataset
+     is_coco = isinstance(val_path, str) and val_path.endswith('val2017.txt')  # COCO dataset
+
+     # Model
+     check_suffix(weights, '.pt')  # check weights
+     pretrained = weights.endswith('.pt')
+     if pretrained:
+         with torch_distributed_zero_first(LOCAL_RANK):
+             weights = attempt_download(weights)  # download if not found locally
+         ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
+         model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+         exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
+         csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
+         csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
+         model.load_state_dict(csd, strict=False)  # load
+         LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
+     else:
+         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+     amp = check_amp(model)  # check AMP
+
+     # Freeze
+     freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
+     for k, v in model.named_parameters():
+         # v.requires_grad = True  # train all layers TODO: uncomment this line as in master
+         # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
+         if any(x in k for x in freeze):
+             LOGGER.info(f'freezing {k}')
+             v.requires_grad = False
+
+     # Image size
+     gs = max(int(model.stride.max()), 32)  # grid size (max stride)
+     imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple
+
+     # Batch size
+     if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
+         batch_size = check_train_batch_size(model, imgsz, amp)
+         loggers.on_params_update({"batch_size": batch_size})
+
+     # Optimizer
+     nbs = 64  # nominal batch size
+     accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
+     hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
+     optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
+
+     # Scheduler
+     if opt.cos_lr:
+         lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
+     elif opt.flat_cos_lr:
+         lf = one_flat_cycle(1, hyp['lrf'], epochs)  # flat cosine 1->hyp['lrf']
+     elif opt.fixed_lr:
+         lf = lambda x: 1.0
+     else:
+         lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
+
+     # def lf(x):  # saw
+     #     return (1 - (x % 30) / 30) * (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']
+     #
+     # def lf(x):  # triangle start at min
+     #     return 2 * abs(x / 30 - math.floor(x / 30 + 1 / 2)) * (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']
+     #
+     # def lf(x):  # triangle start at max
+     #     return 2 * abs(x / 32 + .5 - math.floor(x / 32 + 1)) * (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']
+
+     scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+     # from utils.plots import plot_lr_scheduler; plot_lr_scheduler(optimizer, scheduler, epochs)
+
+     # EMA
+     ema = ModelEMA(model) if RANK in {-1, 0} else None
+
+     # Resume
+     best_fitness, start_epoch = 0.0, 0
+     if pretrained:
+         if resume:
+             best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
+         del ckpt, csd
+
+     # DP mode
+     if cuda and RANK == -1 and torch.cuda.device_count() > 1:
+         LOGGER.warning('WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.')
+         model = torch.nn.DataParallel(model)
+
+     # SyncBatchNorm
+     if opt.sync_bn and cuda and RANK != -1:
+         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
+         LOGGER.info('Using SyncBatchNorm()')
+
+     # Trainloader
+     train_loader, dataset = create_dataloader(train_path,
+                                               imgsz,
+                                               batch_size // WORLD_SIZE,
+                                               gs,
+                                               single_cls,
+                                               hyp=hyp,
+                                               augment=True,
+                                               cache=None if opt.cache == 'val' else opt.cache,
+                                               rect=opt.rect,
+                                               rank=LOCAL_RANK,
+                                               workers=workers,
+                                               image_weights=opt.image_weights,
+                                               close_mosaic=opt.close_mosaic != 0,
+                                               quad=opt.quad,
+                                               prefix=colorstr('train: '),
+                                               shuffle=True,
+                                               min_items=opt.min_items)
+     labels = np.concatenate(dataset.labels, 0)
+     mlc = int(labels[:, 0].max())  # max label class
+     assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
+
+     # Process 0
+     if RANK in {-1, 0}:
+         val_loader = create_dataloader(val_path,
+                                        imgsz,
+                                        batch_size // WORLD_SIZE * 2,
+                                        gs,
+                                        single_cls,
+                                        hyp=hyp,
+                                        cache=None if noval else opt.cache,
+                                        rect=True,
+                                        rank=-1,
+                                        workers=workers * 2,
+                                        pad=0.5,
+                                        prefix=colorstr('val: '))[0]
+
+         if not resume:
+             # if not opt.noautoanchor:
+             #     check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)  # run AutoAnchor
+             model.half().float()  # pre-reduce anchor precision
+
+         callbacks.run('on_pretrain_routine_end', labels, names)
+
+     # DDP mode
+     if cuda and RANK != -1:
+         model = smart_DDP(model)
+
+     # Model attributes
+     nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
+     # hyp['box'] *= 3 / nl  # scale to layers
+     # hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
+     # hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
+     hyp['label_smoothing'] = opt.label_smoothing
+     model.nc = nc  # attach number of classes to model
+     model.hyp = hyp  # attach hyperparameters to model
+     model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
+     model.names = names
+
+     # Start training
+     t0 = time.time()
+     nb = len(train_loader)  # number of batches
+     nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
+     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
+     last_opt_step = -1
+     maps = np.zeros(nc)  # mAP per class
+     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
+     scheduler.last_epoch = start_epoch - 1  # do not move
+     scaler = torch.cuda.amp.GradScaler(enabled=amp)
+     stopper, stop = EarlyStopping(patience=opt.patience), False
+     compute_loss = ComputeLoss(model)  # init loss class
+     callbacks.run('on_train_start')
+     LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
+                 f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
+                 f"Logging results to {colorstr('bold', save_dir)}\n"
+                 f'Starting training for {epochs} epochs...')
+     for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
+         callbacks.run('on_train_epoch_start')
+         model.train()
+
+         # Update image weights (optional, single-GPU only)
+         if opt.image_weights:
+             cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
+             iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
+             dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
+         if epoch == (epochs - opt.close_mosaic):
+             LOGGER.info("Closing dataloader mosaic")
+             dataset.mosaic = False
+
+         # Update mosaic border (optional)
+         # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
+         # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders
+
+         mloss = torch.zeros(3, device=device)  # mean losses
+         if RANK != -1:
+             train_loader.sampler.set_epoch(epoch)
+         pbar = enumerate(train_loader)
+         LOGGER.info(('\n' + '%11s' * 7) % ('Epoch', 'GPU_mem', 'box_loss', 'cls_loss', 'dfl_loss', 'Instances', 'Size'))
+         if RANK in {-1, 0}:
+             pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
+         optimizer.zero_grad()
+         for i, (imgs, targets, paths, _) in pbar:  # batch -------------------------------------------------------------
+             callbacks.run('on_train_batch_start')
+             ni = i + nb * epoch  # number integrated batches (since train start)
+             imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0
+
+             # Warmup
+             if ni <= nw:
+                 xi = [0, nw]  # x interp
+                 # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
+                 accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
+                 for j, x in enumerate(optimizer.param_groups):
+                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                     x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
+                     if 'momentum' in x:
+                         x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
+
+             # Multi-scale
+             if opt.multi_scale:
+                 sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs  # size
+                 sf = sz / max(imgs.shape[2:])  # scale factor
+                 if sf != 1:
+                     ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
+                     imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
+
+             # Forward
+             with torch.cuda.amp.autocast(amp):
+                 pred = model(imgs)  # forward
+                 loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
+                 if RANK != -1:
+                     loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
+                 if opt.quad:
+                     loss *= 4.
+
+             # Backward
+             scaler.scale(loss).backward()
+
+             # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
+             if ni - last_opt_step >= accumulate:
+                 scaler.unscale_(optimizer)  # unscale gradients
+                 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
+                 scaler.step(optimizer)  # optimizer.step
+                 scaler.update()
+                 optimizer.zero_grad()
+                 if ema:
+                     ema.update(model)
+                 last_opt_step = ni
+
+             # Log
+             if RANK in {-1, 0}:
+                 mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
+                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
+                 pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
+                                      (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
+                 callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths, list(mloss))
+                 if callbacks.stop_training:
+                     return
+             # end batch ------------------------------------------------------------------------------------------------
+
+         # Scheduler
+         lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
+         scheduler.step()
+
+         if RANK in {-1, 0}:
+             # mAP
+             callbacks.run('on_train_epoch_end', epoch=epoch)
+             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
+             final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
+             if not noval or final_epoch:  # Calculate mAP
+                 results, maps, _ = validate.run(data_dict,
+                                                 batch_size=batch_size // WORLD_SIZE * 2,
+                                                 imgsz=imgsz,
+                                                 half=amp,
+                                                 model=ema.ema,
+                                                 single_cls=single_cls,
+                                                 dataloader=val_loader,
+                                                 save_dir=save_dir,
+                                                 plots=False,
+                                                 callbacks=callbacks,
+                                                 compute_loss=compute_loss)
+
+             # Update best mAP
+             fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
+             stop = stopper(epoch=epoch, fitness=fi)  # early stop check
+             if fi > best_fitness:
+                 best_fitness = fi
+             log_vals = list(mloss) + list(results) + lr
+             callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
+
+             # Save model
+             if (not nosave) or (final_epoch and not evolve):  # if save
+                 ckpt = {
+                     'epoch': epoch,
+                     'best_fitness': best_fitness,
+                     'model': deepcopy(de_parallel(model)).half(),
+                     'ema': deepcopy(ema.ema).half(),
+                     'updates': ema.updates,
+                     'optimizer': optimizer.state_dict(),
+                     'opt': vars(opt),
+                     'git': GIT_INFO,  # {remote, branch, commit} if a git repo
+                     'date': datetime.now().isoformat()}
+
+                 # Save last, best and delete
+                 torch.save(ckpt, last)
+                 if best_fitness == fi:
+                     torch.save(ckpt, best)
+                 if opt.save_period > 0 and epoch % opt.save_period == 0:
+                     torch.save(ckpt, w / f'epoch{epoch}.pt')
+                 del ckpt
+                 callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
+
+         # EarlyStopping
+         if RANK != -1:  # if DDP training
+             broadcast_list = [stop if RANK == 0 else None]
+             dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
+             if RANK != 0:
+                 stop = broadcast_list[0]
+         if stop:
+             break  # must break all DDP ranks
+
+         # end epoch ----------------------------------------------------------------------------------------------------
+     # end training -----------------------------------------------------------------------------------------------------
+     if RANK in {-1, 0}:
+         LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
+         for f in last, best:
+             if f.exists():
+                 strip_optimizer(f)  # strip optimizers
+                 if f is best:
+                     LOGGER.info(f'\nValidating {f}...')
+                     results, _, _ = validate.run(
+                         data_dict,
+                         batch_size=batch_size // WORLD_SIZE * 2,
+                         imgsz=imgsz,
+                         model=attempt_load(f, device).half(),
+                         single_cls=single_cls,
+                         dataloader=val_loader,
+                         save_dir=save_dir,
+                         save_json=is_coco,
+                         verbose=True,
+                         plots=plots,
+                         callbacks=callbacks,
+                         compute_loss=compute_loss)  # val best model with plots
+                     if is_coco:
+                         callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
+
+         callbacks.run('on_train_end', last, best, epoch, results)
+
+     torch.cuda.empty_cache()
+     return results
+
+
+ def parse_opt(known=False):
+     parser = argparse.ArgumentParser()
+     # parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='initial weights path')
+     # parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
+     parser.add_argument('--weights', type=str, default='', help='initial weights path')
+     parser.add_argument('--cfg', type=str, default='yolo.yaml', help='model.yaml path')
+     parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
+     parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-high.yaml', help='hyperparameters path')
+     parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
+     parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
+     parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
+     parser.add_argument('--rect', action='store_true', help='rectangular training')
+     parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
+     parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
+     parser.add_argument('--noval', action='store_true', help='only validate final epoch')
+     parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
+     parser.add_argument('--noplots', action='store_true', help='save no plot files')
+     parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
+     parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
+     parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
+     parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
+     parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
+     parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
461
+ parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
462
+ parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW', 'LION'], default='SGD', help='optimizer')
463
+ parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
464
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
465
+ parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
466
+ parser.add_argument('--name', default='exp', help='save to project/name')
467
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
468
+ parser.add_argument('--quad', action='store_true', help='quad dataloader')
469
+ parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
470
+ parser.add_argument('--flat-cos-lr', action='store_true', help='flat cosine LR scheduler')
471
+ parser.add_argument('--fixed-lr', action='store_true', help='fixed LR scheduler')
472
+ parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
473
+ parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
474
+ parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
475
+ parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
476
+ parser.add_argument('--seed', type=int, default=0, help='Global training seed')
477
+ parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
478
+ parser.add_argument('--min-items', type=int, default=0, help='Experimental')
479
+ parser.add_argument('--close-mosaic', type=int, default=0, help='Experimental')
480
+
481
+ # Logger arguments
482
+ parser.add_argument('--entity', default=None, help='Entity')
483
+ parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='Upload data, "val" option')
484
+ parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval')
485
+ parser.add_argument('--artifact_alias', type=str, default='latest', help='Version of dataset artifact to use')
486
+
487
+ return parser.parse_known_args()[0] if known else parser.parse_args()
488
+
489
+
490
+ def main(opt, callbacks=Callbacks()):
491
+ # Checks
492
+ if RANK in {-1, 0}:
493
+ print_args(vars(opt))
494
+ #check_git_status()
495
+ #check_requirements()
496
+
497
+ # Resume (from specified or most recent last.pt)
498
+ if opt.resume and not check_comet_resume(opt) and not opt.evolve:
499
+ last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
500
+ opt_yaml = last.parent.parent / 'opt.yaml' # train options yaml
501
+ opt_data = opt.data # original dataset
502
+ if opt_yaml.is_file():
503
+ with open(opt_yaml, errors='ignore') as f:
504
+ d = yaml.safe_load(f)
505
+ else:
506
+ d = torch.load(last, map_location='cpu')['opt']
507
+ opt = argparse.Namespace(**d) # replace
508
+ opt.cfg, opt.weights, opt.resume = '', str(last), True # reinstate
509
+ if is_url(opt_data):
510
+ opt.data = check_file(opt_data) # avoid HUB resume auth timeout
511
+ else:
512
+ opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
513
+ check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
514
+ assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
515
+ if opt.evolve:
516
+ if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve
517
+ opt.project = str(ROOT / 'runs/evolve')
518
+ opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
519
+ if opt.name == 'cfg':
520
+ opt.name = Path(opt.cfg).stem # use model.yaml as name
521
+ opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
522
+
523
+ # DDP mode
524
+ device = select_device(opt.device, batch_size=opt.batch_size)
525
+ if LOCAL_RANK != -1:
526
+ msg = 'is not compatible with YOLO Multi-GPU DDP training'
527
+ assert not opt.image_weights, f'--image-weights {msg}'
528
+ assert not opt.evolve, f'--evolve {msg}'
529
+ assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
530
+ assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
531
+ assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
532
+ torch.cuda.set_device(LOCAL_RANK)
533
+ device = torch.device('cuda', LOCAL_RANK)
534
+ dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
535
+
536
+ # Train
537
+ if not opt.evolve:
538
+ train(opt.hyp, opt, device, callbacks)
539
+
540
+ # Evolve hyperparameters (optional)
541
+ else:
542
+ # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
543
+ meta = {
544
+ 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
545
+ 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
546
+ 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
547
+ 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
548
+ 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
549
+ 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
550
+ 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
551
+ 'box': (1, 0.02, 0.2), # box loss gain
552
+ 'cls': (1, 0.2, 4.0), # cls loss gain
553
+ 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
554
+ 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
555
+ 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
556
+ 'iou_t': (0, 0.1, 0.7), # IoU training threshold
557
+ 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
558
+ 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
559
+ 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
560
+ 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
561
+ 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
562
+ 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
563
+ 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
564
+ 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
565
+ 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
566
+ 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
567
+ 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
568
+ 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
569
+ 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
570
+ 'mosaic': (1, 0.0, 1.0), # image mosaic (probability)
571
+ 'mixup': (1, 0.0, 1.0), # image mixup (probability)
572
+ 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
573
+
574
+ with open(opt.hyp, errors='ignore') as f:
575
+ hyp = yaml.safe_load(f) # load hyps dict
576
+ if 'anchors' not in hyp: # anchors commented in hyp.yaml
577
+ hyp['anchors'] = 3
578
+ if opt.noautoanchor:
579
+ del hyp['anchors'], meta['anchors']
580
+ opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
581
+ # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
582
+ evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
583
+ if opt.bucket:
584
+ os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}') # download evolve.csv if exists
585
+
586
+ for _ in range(opt.evolve): # generations to evolve
587
+ if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
588
+ # Select parent(s)
589
+ parent = 'single' # parent selection method: 'single' or 'weighted'
590
+ x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
591
+ n = min(5, len(x)) # number of previous results to consider
592
+ x = x[np.argsort(-fitness(x))][:n] # top n mutations
593
+ w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
594
+ if parent == 'single' or len(x) == 1:
595
+ # x = x[random.randint(0, n - 1)] # random selection
596
+ x = x[random.choices(range(n), weights=w)[0]] # weighted selection
597
+ elif parent == 'weighted':
598
+ x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
599
+
600
+ # Mutate
601
+ mp, s = 0.8, 0.2 # mutation probability, sigma
602
+ npr = np.random
603
+ npr.seed(int(time.time()))
604
+ g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
605
+ ng = len(meta)
606
+ v = np.ones(ng)
607
+ while all(v == 1): # mutate until a change occurs (prevent duplicates)
608
+ v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
609
+ for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
610
+ hyp[k] = float(x[i + 7] * v[i]) # mutate
611
+
612
+ # Constrain to limits
613
+ for k, v in meta.items():
614
+ hyp[k] = max(hyp[k], v[1]) # lower limit
615
+ hyp[k] = min(hyp[k], v[2]) # upper limit
616
+ hyp[k] = round(hyp[k], 5) # significant digits
617
+
618
+ # Train mutation
619
+ results = train(hyp.copy(), opt, device, callbacks)
620
+ callbacks = Callbacks()
621
+ # Write mutation results
622
+ keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
623
+ 'val/obj_loss', 'val/cls_loss')
624
+ print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
625
+
626
+ # Plot results
627
+ plot_evolve(evolve_csv)
628
+ LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
629
+ f"Results saved to {colorstr('bold', save_dir)}\n"
630
+ f'Usage example: $ python train.py --hyp {evolve_yaml}')
631
+
632
+
633
+ def run(**kwargs):
634
+ # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolo.pt')
635
+ opt = parse_opt(True)
636
+ for k, v in kwargs.items():
637
+ setattr(opt, k, v)
638
+ main(opt)
639
+ return opt
640
+
641
+
642
+ if __name__ == "__main__":
643
+ opt = parse_opt()
644
+ main(opt)
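For reference, the parent-selection and mutation step above can be exercised in isolation. A minimal sketch, assuming NumPy and a toy two-key meta/hyp pair (values hypothetical); the real loop multiplies the best parent row loaded from evolve.csv rather than the current hyp values:

import numpy as np

# Toy stand-ins for the meta/hyp dicts above (hypothetical values)
meta = {'lr0': (1, 1e-5, 1e-1), 'momentum': (0.3, 0.6, 0.98)}  # (mutation gain, lower, upper)
hyp = {'lr0': 0.01, 'momentum': 0.937}

mp, s = 0.8, 0.2  # mutation probability, sigma
g = np.array([meta[k][0] for k in hyp])  # per-key gains 0-1
v = np.ones(len(meta))
while all(v == 1):  # mutate until a change occurs (prevent duplicates)
    v = (g * (np.random.random(len(meta)) < mp) * np.random.randn(len(meta)) * np.random.random() * s + 1).clip(0.3, 3.0)
for (k, (_, low, high)), vi in zip(meta.items(), v):
    hyp[k] = round(min(max(hyp[k] * vi, low), high), 5)  # mutate, then constrain to limits
print(hyp)  # e.g. {'lr0': 0.01123, 'momentum': 0.91352}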
yolov9/train_triple.py ADDED
@@ -0,0 +1,636 @@
1
+ import argparse
2
+ import math
3
+ import os
4
+ import random
5
+ import sys
6
+ import time
7
+ from copy import deepcopy
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.distributed as dist
14
+ import torch.nn as nn
15
+ import yaml
16
+ from torch.optim import lr_scheduler
17
+ from tqdm import tqdm
18
+
19
+ FILE = Path(__file__).resolve()
20
+ ROOT = FILE.parents[0] # YOLO root directory
21
+ if str(ROOT) not in sys.path:
22
+ sys.path.append(str(ROOT)) # add ROOT to PATH
23
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
24
+
25
+ import val_triple as validate # for end-of-epoch mAP
26
+ from models.experimental import attempt_load
27
+ from models.yolo import Model
28
+ from utils.autoanchor import check_anchors
29
+ from utils.autobatch import check_train_batch_size
30
+ from utils.callbacks import Callbacks
31
+ from utils.dataloaders import create_dataloader
32
+ from utils.downloads import attempt_download, is_url
33
+ from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info,
34
+ check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr,
35
+ get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
36
+ labels_to_image_weights, methods, one_cycle, print_args, print_mutation, strip_optimizer,
37
+ yaml_save)
38
+ from utils.loggers import Loggers
39
+ from utils.loggers.comet.comet_utils import check_comet_resume
40
+ from utils.loss_tal_triple import ComputeLoss
41
+ from utils.metrics import fitness
42
+ from utils.plots import plot_evolve
43
+ from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
44
+ smart_resume, torch_distributed_zero_first)
45
+
46
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
47
+ RANK = int(os.getenv('RANK', -1))
48
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
49
+ GIT_INFO = None  # check_git_info()
50
+
51
+
52
+ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
53
+ save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze = \
54
+ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
55
+ opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze
56
+ callbacks.run('on_pretrain_routine_start')
57
+
58
+ # Directories
59
+ w = save_dir / 'weights' # weights dir
60
+ (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
61
+ last, best = w / 'last.pt', w / 'best.pt'
62
+
63
+ # Hyperparameters
64
+ if isinstance(hyp, str):
65
+ with open(hyp, errors='ignore') as f:
66
+ hyp = yaml.safe_load(f) # load hyps dict
67
+ LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
68
+ hyp['anchor_t'] = 5.0 # override anchor-multiple threshold
69
+ opt.hyp = hyp.copy() # for saving hyps to checkpoints
70
+
71
+ # Save run settings
72
+ if not evolve:
73
+ yaml_save(save_dir / 'hyp.yaml', hyp)
74
+ yaml_save(save_dir / 'opt.yaml', vars(opt))
75
+
76
+ # Loggers
77
+ data_dict = None
78
+ if RANK in {-1, 0}:
79
+ loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
80
+
81
+ # Register actions
82
+ for k in methods(loggers):
83
+ callbacks.register_action(k, callback=getattr(loggers, k))
84
+
85
+ # Process custom dataset artifact link
86
+ data_dict = loggers.remote_dataset
87
+ if resume: # If resuming runs from remote artifact
88
+ weights, epochs, hyp, batch_size = opt.weights, opt.epochs, opt.hyp, opt.batch_size
89
+
90
+ # Config
91
+ plots = not evolve and not opt.noplots # create plots
92
+ cuda = device.type != 'cpu'
93
+ init_seeds(opt.seed + 1 + RANK, deterministic=True)
94
+ with torch_distributed_zero_first(LOCAL_RANK):
95
+ data_dict = data_dict or check_dataset(data) # check if None
96
+ train_path, val_path = data_dict['train'], data_dict['val']
97
+ nc = 1 if single_cls else int(data_dict['nc']) # number of classes
98
+ names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
99
+ #is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
100
+ is_coco = isinstance(val_path, str) and val_path.endswith('val2017.txt') # COCO dataset
101
+
102
+ # Model
103
+ check_suffix(weights, '.pt') # check weights
104
+ pretrained = weights.endswith('.pt')
105
+ if pretrained:
106
+ with torch_distributed_zero_first(LOCAL_RANK):
107
+ weights = attempt_download(weights) # download if not found locally
108
+ ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
109
+ model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
110
+ exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
111
+ csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
112
+ csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
113
+ model.load_state_dict(csd, strict=False) # load
114
+ LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
115
+ else:
116
+ model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
117
+ amp = check_amp(model) # check AMP
118
+
119
+ # Freeze
120
+ freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
121
+ for k, v in model.named_parameters():
122
+ # v.requires_grad = True # train all layers TODO: uncomment this line as in master
123
+ # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
124
+ if any(x in k for x in freeze):
125
+ LOGGER.info(f'freezing {k}')
126
+ v.requires_grad = False
127
+
128
+ # Image size
129
+ gs = max(int(model.stride.max()), 32) # grid size (max stride)
130
+ imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
131
+
132
+ # Batch size
133
+ if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
134
+ batch_size = check_train_batch_size(model, imgsz, amp)
135
+ loggers.on_params_update({"batch_size": batch_size})
136
+
137
+ # Optimizer
138
+ nbs = 64 # nominal batch size
139
+ accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
140
+ hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
141
+ optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
142
+
143
+ # Scheduler
144
+ if opt.cos_lr:
145
+ lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
146
+ else:
147
+ lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
148
+
149
+ # def lf(x): # saw
150
+ # return (1 - (x % 30) / 30) * (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']
151
+ #
152
+ # def lf(x): # triangle start at min
153
+ # return 2 * abs(x / 30 - math.floor(x / 30 + 1 / 2)) * (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']
154
+ #
155
+ # def lf(x): # triangle start at max
156
+ # return 2 * abs(x / 32 + .5 - math.floor(x / 32 + 1)) * (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']
157
+
158
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
159
+ # from utils.plots import plot_lr_scheduler; plot_lr_scheduler(optimizer, scheduler, epochs)
160
+
161
+ # EMA
162
+ ema = ModelEMA(model) if RANK in {-1, 0} else None
163
+
164
+ # Resume
165
+ best_fitness, start_epoch = 0.0, 0
166
+ if pretrained:
167
+ if resume:
168
+ best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
169
+ del ckpt, csd
170
+
171
+ # DP mode
172
+ if cuda and RANK == -1 and torch.cuda.device_count() > 1:
173
+ LOGGER.warning('WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.')
174
+ model = torch.nn.DataParallel(model)
175
+
176
+ # SyncBatchNorm
177
+ if opt.sync_bn and cuda and RANK != -1:
178
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
179
+ LOGGER.info('Using SyncBatchNorm()')
180
+
181
+ # Trainloader
182
+ train_loader, dataset = create_dataloader(train_path,
183
+ imgsz,
184
+ batch_size // WORLD_SIZE,
185
+ gs,
186
+ single_cls,
187
+ hyp=hyp,
188
+ augment=True,
189
+ cache=None if opt.cache == 'val' else opt.cache,
190
+ rect=opt.rect,
191
+ rank=LOCAL_RANK,
192
+ workers=workers,
193
+ image_weights=opt.image_weights,
194
+ close_mosaic=opt.close_mosaic != 0,
195
+ quad=opt.quad,
196
+ prefix=colorstr('train: '),
197
+ shuffle=True,
198
+ min_items=opt.min_items)
199
+ labels = np.concatenate(dataset.labels, 0)
200
+ mlc = int(labels[:, 0].max()) # max label class
201
+ assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
202
+
203
+ # Process 0
204
+ if RANK in {-1, 0}:
205
+ val_loader = create_dataloader(val_path,
206
+ imgsz,
207
+ batch_size // WORLD_SIZE * 2,
208
+ gs,
209
+ single_cls,
210
+ hyp=hyp,
211
+ cache=None if noval else opt.cache,
212
+ rect=True,
213
+ rank=-1,
214
+ workers=workers * 2,
215
+ pad=0.5,
216
+ prefix=colorstr('val: '))[0]
217
+
218
+ if not resume:
219
+ # if not opt.noautoanchor:
220
+ # check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
221
+ model.half().float() # pre-reduce anchor precision
222
+
223
+ callbacks.run('on_pretrain_routine_end', labels, names)
224
+
225
+ # DDP mode
226
+ if cuda and RANK != -1:
227
+ model = smart_DDP(model)
228
+
229
+ # Model attributes
230
+ nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
231
+ #hyp['box'] *= 3 / nl # scale to layers
232
+ #hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
233
+ #hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
234
+ hyp['label_smoothing'] = opt.label_smoothing
235
+ model.nc = nc # attach number of classes to model
236
+ model.hyp = hyp # attach hyperparameters to model
237
+ model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
238
+ model.names = names
239
+
240
+ # Start training
241
+ t0 = time.time()
242
+ nb = len(train_loader) # number of batches
243
+ nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations)
244
+ # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
245
+ last_opt_step = -1
246
+ maps = np.zeros(nc) # mAP per class
247
+ results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
248
+ scheduler.last_epoch = start_epoch - 1 # do not move
249
+ scaler = torch.cuda.amp.GradScaler(enabled=amp)
250
+ stopper, stop = EarlyStopping(patience=opt.patience), False
251
+ compute_loss = ComputeLoss(model) # init loss class
252
+ callbacks.run('on_train_start')
253
+ LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
254
+ f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
255
+ f"Logging results to {colorstr('bold', save_dir)}\n"
256
+ f'Starting training for {epochs} epochs...')
257
+ for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
258
+ callbacks.run('on_train_epoch_start')
259
+ model.train()
260
+
261
+ # Update image weights (optional, single-GPU only)
262
+ if opt.image_weights:
263
+ cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
264
+ iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
265
+ dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
266
+ if epoch == (epochs - opt.close_mosaic):
267
+ LOGGER.info("Closing dataloader mosaic")
268
+ dataset.mosaic = False
269
+
270
+ # Update mosaic border (optional)
271
+ # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
272
+ # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
273
+
274
+ mloss = torch.zeros(3, device=device) # mean losses
275
+ if RANK != -1:
276
+ train_loader.sampler.set_epoch(epoch)
277
+ pbar = enumerate(train_loader)
278
+ LOGGER.info(('\n' + '%11s' * 7) % ('Epoch', 'GPU_mem', 'box_loss', 'cls_loss', 'dfl_loss', 'Instances', 'Size'))
279
+ if RANK in {-1, 0}:
280
+ pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT) # progress bar
281
+ optimizer.zero_grad()
282
+ for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
283
+ callbacks.run('on_train_batch_start')
284
+ ni = i + nb * epoch # number integrated batches (since train start)
285
+ imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
286
+
287
+ # Warmup
288
+ if ni <= nw:
289
+ xi = [0, nw] # x interp
290
+ # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
291
+ accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
292
+ for j, x in enumerate(optimizer.param_groups):
293
+ # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
294
+ x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
295
+ if 'momentum' in x:
296
+ x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
297
+
298
+ # Multi-scale
299
+ if opt.multi_scale:
300
+ sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
301
+ sf = sz / max(imgs.shape[2:]) # scale factor
302
+ if sf != 1:
303
+ ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
304
+ imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
305
+
306
+ # Forward
307
+ with torch.cuda.amp.autocast(amp):
308
+ pred = model(imgs) # forward
309
+ loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
310
+ if RANK != -1:
311
+ loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
312
+ if opt.quad:
313
+ loss *= 4.
314
+
315
+ # Backward
316
+ scaler.scale(loss).backward()
317
+
318
+ # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
319
+ if ni - last_opt_step >= accumulate:
320
+ scaler.unscale_(optimizer) # unscale gradients
321
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients
322
+ scaler.step(optimizer) # optimizer.step
323
+ scaler.update()
324
+ optimizer.zero_grad()
325
+ if ema:
326
+ ema.update(model)
327
+ last_opt_step = ni
328
+
329
+ # Log
330
+ if RANK in {-1, 0}:
331
+ mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
332
+ mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
333
+ pbar.set_description(('%11s' * 2 + '%11.4g' * 5) %
334
+ (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
335
+ callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths, list(mloss))
336
+ if callbacks.stop_training:
337
+ return
338
+ # end batch ------------------------------------------------------------------------------------------------
339
+
340
+ # Scheduler
341
+ lr = [x['lr'] for x in optimizer.param_groups] # for loggers
342
+ scheduler.step()
343
+
344
+ if RANK in {-1, 0}:
345
+ # mAP
346
+ callbacks.run('on_train_epoch_end', epoch=epoch)
347
+ ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
348
+ final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
349
+ if not noval or final_epoch: # Calculate mAP
350
+ results, maps, _ = validate.run(data_dict,
351
+ batch_size=batch_size // WORLD_SIZE * 2,
352
+ imgsz=imgsz,
353
+ half=amp,
354
+ model=ema.ema,
355
+ single_cls=single_cls,
356
+ dataloader=val_loader,
357
+ save_dir=save_dir,
358
+ plots=False,
359
+ callbacks=callbacks,
360
+ compute_loss=compute_loss)
361
+
362
+ # Update best mAP
363
+ fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
364
+ stop = stopper(epoch=epoch, fitness=fi) # early stop check
365
+ if fi > best_fitness:
366
+ best_fitness = fi
367
+ log_vals = list(mloss) + list(results) + lr
368
+ callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
369
+
370
+ # Save model
371
+ if (not nosave) or (final_epoch and not evolve): # if save
372
+ ckpt = {
373
+ 'epoch': epoch,
374
+ 'best_fitness': best_fitness,
375
+ 'model': deepcopy(de_parallel(model)).half(),
376
+ 'ema': deepcopy(ema.ema).half(),
377
+ 'updates': ema.updates,
378
+ 'optimizer': optimizer.state_dict(),
379
+ 'opt': vars(opt),
380
+ 'git': GIT_INFO, # {remote, branch, commit} if a git repo
381
+ 'date': datetime.now().isoformat()}
382
+
383
+ # Save last, best and delete
384
+ torch.save(ckpt, last)
385
+ if best_fitness == fi:
386
+ torch.save(ckpt, best)
387
+ if opt.save_period > 0 and epoch % opt.save_period == 0:
388
+ torch.save(ckpt, w / f'epoch{epoch}.pt')
389
+ del ckpt
390
+ callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
391
+
392
+ # EarlyStopping
393
+ if RANK != -1: # if DDP training
394
+ broadcast_list = [stop if RANK == 0 else None]
395
+ dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
396
+ if RANK != 0:
397
+ stop = broadcast_list[0]
398
+ if stop:
399
+ break # must break all DDP ranks
400
+
401
+ # end epoch ----------------------------------------------------------------------------------------------------
402
+ # end training -----------------------------------------------------------------------------------------------------
403
+ if RANK in {-1, 0}:
404
+ LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
405
+ for f in last, best:
406
+ if f.exists():
407
+ strip_optimizer(f) # strip optimizers
408
+ if f is best:
409
+ LOGGER.info(f'\nValidating {f}...')
410
+ results, _, _ = validate.run(
411
+ data_dict,
412
+ batch_size=batch_size // WORLD_SIZE * 2,
413
+ imgsz=imgsz,
414
+ model=attempt_load(f, device).half(),
415
+ single_cls=single_cls,
416
+ dataloader=val_loader,
417
+ save_dir=save_dir,
418
+ save_json=is_coco,
419
+ verbose=True,
420
+ plots=plots,
421
+ callbacks=callbacks,
422
+ compute_loss=compute_loss) # val best model with plots
423
+ if is_coco:
424
+ callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
425
+
426
+ callbacks.run('on_train_end', last, best, epoch, results)
427
+
428
+ torch.cuda.empty_cache()
429
+ return results
430
+
431
+
432
+ def parse_opt(known=False):
433
+ parser = argparse.ArgumentParser()
434
+ # parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='initial weights path')
435
+ # parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
436
+ parser.add_argument('--weights', type=str, default='', help='initial weights path')
437
+ parser.add_argument('--cfg', type=str, default='yolo.yaml', help='model.yaml path')
438
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
439
+ parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-high.yaml', help='hyperparameters path')
440
+ parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
441
+ parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
442
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
443
+ parser.add_argument('--rect', action='store_true', help='rectangular training')
444
+ parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
445
+ parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
446
+ parser.add_argument('--noval', action='store_true', help='only validate final epoch')
447
+ parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
448
+ parser.add_argument('--noplots', action='store_true', help='save no plot files')
449
+ parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
450
+ parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
451
+ parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
452
+ parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
453
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
454
+ parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
455
+ parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
456
+ parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW', 'LION'], default='SGD', help='optimizer')
457
+ parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
458
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
459
+ parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
460
+ parser.add_argument('--name', default='exp', help='save to project/name')
461
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
462
+ parser.add_argument('--quad', action='store_true', help='quad dataloader')
463
+ parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
464
+ parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
465
+ parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
466
+ parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
467
+ parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
468
+ parser.add_argument('--seed', type=int, default=0, help='Global training seed')
469
+ parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
470
+ parser.add_argument('--min-items', type=int, default=0, help='Experimental')
471
+ parser.add_argument('--close-mosaic', type=int, default=0, help='Experimental')
472
+
473
+ # Logger arguments
474
+ parser.add_argument('--entity', default=None, help='Entity')
475
+ parser.add_argument('--upload_dataset', nargs='?', const=True, default=False, help='Upload data, "val" option')
476
+ parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval')
477
+ parser.add_argument('--artifact_alias', type=str, default='latest', help='Version of dataset artifact to use')
478
+
479
+ return parser.parse_known_args()[0] if known else parser.parse_args()
480
+
481
+
482
+ def main(opt, callbacks=Callbacks()):
483
+ # Checks
484
+ if RANK in {-1, 0}:
485
+ print_args(vars(opt))
486
+ #check_git_status()
487
+ #check_requirements()
488
+
489
+ # Resume (from specified or most recent last.pt)
490
+ if opt.resume and not check_comet_resume(opt) and not opt.evolve:
491
+ last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
492
+ opt_yaml = last.parent.parent / 'opt.yaml' # train options yaml
493
+ opt_data = opt.data # original dataset
494
+ if opt_yaml.is_file():
495
+ with open(opt_yaml, errors='ignore') as f:
496
+ d = yaml.safe_load(f)
497
+ else:
498
+ d = torch.load(last, map_location='cpu')['opt']
499
+ opt = argparse.Namespace(**d) # replace
500
+ opt.cfg, opt.weights, opt.resume = '', str(last), True # reinstate
501
+ if is_url(opt_data):
502
+ opt.data = check_file(opt_data) # avoid HUB resume auth timeout
503
+ else:
504
+ opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
505
+ check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
506
+ assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
507
+ if opt.evolve:
508
+ if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve
509
+ opt.project = str(ROOT / 'runs/evolve')
510
+ opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
511
+ if opt.name == 'cfg':
512
+ opt.name = Path(opt.cfg).stem # use model.yaml as name
513
+ opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
514
+
515
+ # DDP mode
516
+ device = select_device(opt.device, batch_size=opt.batch_size)
517
+ if LOCAL_RANK != -1:
518
+ msg = 'is not compatible with YOLO Multi-GPU DDP training'
519
+ assert not opt.image_weights, f'--image-weights {msg}'
520
+ assert not opt.evolve, f'--evolve {msg}'
521
+ assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
522
+ assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
523
+ assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
524
+ torch.cuda.set_device(LOCAL_RANK)
525
+ device = torch.device('cuda', LOCAL_RANK)
526
+ dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
527
+
528
+ # Train
529
+ if not opt.evolve:
530
+ train(opt.hyp, opt, device, callbacks)
531
+
532
+ # Evolve hyperparameters (optional)
533
+ else:
534
+ # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
535
+ meta = {
536
+ 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
537
+ 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
538
+ 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
539
+ 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
540
+ 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
541
+ 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
542
+ 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
543
+ 'box': (1, 0.02, 0.2), # box loss gain
544
+ 'cls': (1, 0.2, 4.0), # cls loss gain
545
+ 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
546
+ 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
547
+ 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
548
+ 'iou_t': (0, 0.1, 0.7), # IoU training threshold
549
+ 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
550
+ 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
551
+ 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
552
+ 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
553
+ 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
554
+ 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
555
+ 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
556
+ 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
557
+ 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
558
+ 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
559
+ 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
560
+ 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
561
+ 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
562
+ 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
563
+ 'mixup': (1, 0.0, 1.0), # image mixup (probability)
564
+ 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
565
+
566
+ with open(opt.hyp, errors='ignore') as f:
567
+ hyp = yaml.safe_load(f) # load hyps dict
568
+ if 'anchors' not in hyp: # anchors commented in hyp.yaml
569
+ hyp['anchors'] = 3
570
+ if opt.noautoanchor:
571
+ del hyp['anchors'], meta['anchors']
572
+ opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
573
+ # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
574
+ evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
575
+ if opt.bucket:
576
+ os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}') # download evolve.csv if exists
577
+
578
+ for _ in range(opt.evolve): # generations to evolve
579
+ if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
580
+ # Select parent(s)
581
+ parent = 'single' # parent selection method: 'single' or 'weighted'
582
+ x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
583
+ n = min(5, len(x)) # number of previous results to consider
584
+ x = x[np.argsort(-fitness(x))][:n] # top n mutations
585
+ w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
586
+ if parent == 'single' or len(x) == 1:
587
+ # x = x[random.randint(0, n - 1)] # random selection
588
+ x = x[random.choices(range(n), weights=w)[0]] # weighted selection
589
+ elif parent == 'weighted':
590
+ x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
591
+
592
+ # Mutate
593
+ mp, s = 0.8, 0.2 # mutation probability, sigma
594
+ npr = np.random
595
+ npr.seed(int(time.time()))
596
+ g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
597
+ ng = len(meta)
598
+ v = np.ones(ng)
599
+ while all(v == 1): # mutate until a change occurs (prevent duplicates)
600
+ v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
601
+ for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
602
+ hyp[k] = float(x[i + 7] * v[i]) # mutate
603
+
604
+ # Constrain to limits
605
+ for k, v in meta.items():
606
+ hyp[k] = max(hyp[k], v[1]) # lower limit
607
+ hyp[k] = min(hyp[k], v[2]) # upper limit
608
+ hyp[k] = round(hyp[k], 5) # significant digits
609
+
610
+ # Train mutation
611
+ results = train(hyp.copy(), opt, device, callbacks)
612
+ callbacks = Callbacks()
613
+ # Write mutation results
614
+ keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss',
615
+ 'val/obj_loss', 'val/cls_loss')
616
+ print_mutation(keys, results, hyp.copy(), save_dir, opt.bucket)
617
+
618
+ # Plot results
619
+ plot_evolve(evolve_csv)
620
+ LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
621
+ f"Results saved to {colorstr('bold', save_dir)}\n"
622
+ f'Usage example: $ python train.py --hyp {evolve_yaml}')
623
+
624
+
625
+ def run(**kwargs):
626
+ # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolo.pt')
627
+ opt = parse_opt(True)
628
+ for k, v in kwargs.items():
629
+ setattr(opt, k, v)
630
+ main(opt)
631
+ return opt
632
+
633
+
634
+ if __name__ == "__main__":
635
+ opt = parse_opt()
636
+ main(opt)
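The warmup block in train() above ramps accumulate and each parameter group's learning rate over the first nw iterations with np.interp; the bias group falls from warmup_bias_lr while the other groups rise from zero. A minimal sketch of that schedule, assuming NumPy, illustrative values for nw/nbs/batch_size, and ignoring the epoch LR-lambda factor for clarity:

import numpy as np

nw, nbs, batch_size = 1000, 64, 16  # illustrative warmup iterations, nominal and actual batch size
warmup_bias_lr, lr0 = 0.1, 0.01  # assumed hyp values
xi = [0, nw]  # x interp

for ni in (0, 250, 500, 1000):  # selected integrated-batch counts
    accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
    bias_lr = np.interp(ni, xi, [warmup_bias_lr, lr0])  # group 0: falls 0.1 -> lr0
    other_lr = np.interp(ni, xi, [0.0, lr0])  # other groups: rise 0.0 -> lr0
    print(f'ni={ni:4} accumulate={accumulate:.0f} bias_lr={bias_lr:.4f} other_lr={other_lr:.4f}')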
yolov9/utils/__init__.py ADDED
@@ -0,0 +1,75 @@
1
+ import contextlib
2
+ import platform
3
+ import threading
4
+
5
+
6
+ def emojis(str=''):
7
+ # Return platform-dependent emoji-safe version of string
8
+ return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
9
+
10
+
11
+ class TryExcept(contextlib.ContextDecorator):
12
+ # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
13
+ def __init__(self, msg=''):
14
+ self.msg = msg
15
+
16
+ def __enter__(self):
17
+ pass
18
+
19
+ def __exit__(self, exc_type, value, traceback):
20
+ if value:
21
+ print(emojis(f"{self.msg}{': ' if self.msg else ''}{value}"))
22
+ return True
23
+
24
+
25
+ def threaded(func):
26
+ # Multi-threads a target function and returns thread. Usage: @threaded decorator
27
+ def wrapper(*args, **kwargs):
28
+ thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
29
+ thread.start()
30
+ return thread
31
+
32
+ return wrapper
33
+
34
+
35
+ def join_threads(verbose=False):
36
+ # Join all daemon threads, i.e. atexit.register(lambda: join_threads())
37
+ main_thread = threading.current_thread()
38
+ for t in threading.enumerate():
39
+ if t is not main_thread:
40
+ if verbose:
41
+ print(f'Joining thread {t.name}')
42
+ t.join()
43
+
44
+
45
+ def notebook_init(verbose=True):
46
+ # Check system software and hardware
47
+ print('Checking setup...')
48
+
49
+ import os
50
+ import shutil
51
+
52
+ from utils.general import check_font, check_requirements, is_colab
53
+ from utils.torch_utils import select_device # imports
54
+
55
+ check_font()
56
+
57
+ import psutil
58
+ from IPython import display # to display images and clear console output
59
+
60
+ if is_colab():
61
+ shutil.rmtree('/content/sample_data', ignore_errors=True) # remove colab /sample_data directory
62
+
63
+ # System info
64
+ if verbose:
65
+ gb = 1 << 30 # bytes to GiB (1024 ** 3)
66
+ ram = psutil.virtual_memory().total
67
+ total, used, free = shutil.disk_usage("/")
68
+ display.clear_output()
69
+ s = f'({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)'
70
+ else:
71
+ s = ''
72
+
73
+ select_device(newline=False)
74
+ print(emojis(f'Setup complete ✅ {s}'))
75
+ return display
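Both decorators above are used exactly as their comments describe; a minimal, self-contained example, assuming the yolov9 directory is on sys.path so this module imports as utils:

import time

from utils import TryExcept, threaded

@TryExcept('plot failed')
def risky_plot():
    raise ValueError('no data')  # swallowed by __exit__, printed instead of raised

@threaded
def slow_task(n):
    time.sleep(0.1)
    print(f'done {n}')

risky_plot()  # prints "plot failed: no data" and execution continues
t = slow_task(3)  # returns the started daemon thread immediately
t.join()  # wait for it here, or use join_threads() at exit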
yolov9/utils/activations.py ADDED
@@ -0,0 +1,98 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SiLU(nn.Module):
7
+ # SiLU activation https://arxiv.org/pdf/1606.08415.pdf
8
+ @staticmethod
9
+ def forward(x):
10
+ return x * torch.sigmoid(x)
11
+
12
+
13
+ class Hardswish(nn.Module):
14
+ # Hard-SiLU activation
15
+ @staticmethod
16
+ def forward(x):
17
+ # return x * F.hardsigmoid(x) # for TorchScript and CoreML
18
+ return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0 # for TorchScript, CoreML and ONNX
19
+
20
+
21
+ class Mish(nn.Module):
22
+ # Mish activation https://github.com/digantamisra98/Mish
23
+ @staticmethod
24
+ def forward(x):
25
+ return x * F.softplus(x).tanh()
26
+
27
+
28
+ class MemoryEfficientMish(nn.Module):
29
+ # Mish activation memory-efficient
30
+ class F(torch.autograd.Function):
31
+
32
+ @staticmethod
33
+ def forward(ctx, x):
34
+ ctx.save_for_backward(x)
35
+ return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x)))
36
+
37
+ @staticmethod
38
+ def backward(ctx, grad_output):
39
+ x = ctx.saved_tensors[0]
40
+ sx = torch.sigmoid(x)
41
+ fx = F.softplus(x).tanh()
42
+ return grad_output * (fx + x * sx * (1 - fx * fx))
43
+
44
+ def forward(self, x):
45
+ return self.F.apply(x)
46
+
47
+
48
+ class FReLU(nn.Module):
49
+ # FReLU activation https://arxiv.org/abs/2007.11824
50
+ def __init__(self, c1, k=3): # ch_in, kernel
51
+ super().__init__()
52
+ self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False)
53
+ self.bn = nn.BatchNorm2d(c1)
54
+
55
+ def forward(self, x):
56
+ return torch.max(x, self.bn(self.conv(x)))
57
+
58
+
59
+ class AconC(nn.Module):
60
+ r""" ACON activation (activate or not)
61
+ AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter
62
+ according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
63
+ """
64
+
65
+ def __init__(self, c1):
66
+ super().__init__()
67
+ self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
68
+ self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
69
+ self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))
70
+
71
+ def forward(self, x):
72
+ dpx = (self.p1 - self.p2) * x
73
+ return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x
74
+
75
+
76
+ class MetaAconC(nn.Module):
77
+ r""" ACON activation (activate or not)
78
+ MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network
79
+ according to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>.
80
+ """
81
+
82
+ def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r
83
+ super().__init__()
84
+ c2 = max(r, c1 // r)
85
+ self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))
86
+ self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))
87
+ self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)
88
+ self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)
89
+ # self.bn1 = nn.BatchNorm2d(c2)
90
+ # self.bn2 = nn.BatchNorm2d(c1)
91
+
92
+ def forward(self, x):
93
+ y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)
94
+ # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891
95
+ # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) # bug/unstable
96
+ beta = torch.sigmoid(self.fc2(self.fc1(y))) # bug patch BN layers removed
97
+ dpx = (self.p1 - self.p2) * x
98
+ return dpx * torch.sigmoid(beta * dpx) + self.p2 * x
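The SiLU, Hardswish and Mish modules above are export-friendly re-implementations of standard activations; a quick numerical sanity check against PyTorch's built-ins (F.mish requires PyTorch >= 1.9), assuming the same sys.path setup as the training scripts:

import torch
import torch.nn.functional as F

from utils.activations import Hardswish, Mish, SiLU

x = torch.randn(4, 8)
assert torch.allclose(SiLU()(x), F.silu(x), atol=1e-6)  # x * sigmoid(x)
assert torch.allclose(Hardswish()(x), F.hardswish(x), atol=1e-6)  # x * relu6(x + 3) / 6
assert torch.allclose(Mish()(x), F.mish(x), atol=1e-6)  # x * tanh(softplus(x))
print('activation re-implementations match the built-ins')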
yolov9/utils/augmentations.py ADDED
@@ -0,0 +1,395 @@
1
+ import math
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as T
8
+ import torchvision.transforms.functional as TF
9
+
10
+ from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box, xywhn2xyxy
11
+ from utils.metrics import bbox_ioa
12
+
13
+ IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
14
+ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
15
+
16
+
17
+ class Albumentations:
18
+ # YOLOv5 Albumentations class (optional, only used if package is installed)
19
+ def __init__(self, size=640):
20
+ self.transform = None
21
+ prefix = colorstr('albumentations: ')
22
+ try:
23
+ import albumentations as A
24
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
25
+
26
+ T = [
27
+ A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
28
+ A.Blur(p=0.01),
29
+ A.MedianBlur(p=0.01),
30
+ A.ToGray(p=0.01),
31
+ A.CLAHE(p=0.01),
32
+ A.RandomBrightnessContrast(p=0.0),
33
+ A.RandomGamma(p=0.0),
34
+ A.ImageCompression(quality_lower=75, p=0.0)] # transforms
35
+ self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
36
+
37
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
38
+ except ImportError: # package not installed, skip
39
+ pass
40
+ except Exception as e:
41
+ LOGGER.info(f'{prefix}{e}')
42
+
43
+ def __call__(self, im, labels, p=1.0):
44
+ if self.transform and random.random() < p:
45
+ new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed
46
+ im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])])
47
+ return im, labels
48
+
49
+
50
+ def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
51
+ # Normalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
52
+ return TF.normalize(x, mean, std, inplace=inplace)
53
+
54
+
55
+ def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
56
+ # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
57
+ for i in range(3):
58
+ x[:, i] = x[:, i] * std[i] + mean[i]
59
+ return x
60
+
61
+
62
+ def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
63
+ # HSV color-space augmentation
64
+ if hgain or sgain or vgain:
65
+ r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
66
+ hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
67
+ dtype = im.dtype # uint8
68
+
69
+ x = np.arange(0, 256, dtype=r.dtype)
70
+ lut_hue = ((x * r[0]) % 180).astype(dtype)
71
+ lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
72
+ lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
73
+
74
+ im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
75
+ cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
76
+
77
+
78
+ def hist_equalize(im, clahe=True, bgr=False):
79
+ # Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255
80
+ yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV)
81
+ if clahe:
82
+ c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
83
+ yuv[:, :, 0] = c.apply(yuv[:, :, 0])
84
+ else:
85
+ yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram
86
+ return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB
87
+
88
+
89
+ def replicate(im, labels):
90
+ # Replicate labels
91
+ h, w = im.shape[:2]
92
+ boxes = labels[:, 1:].astype(int)
93
+ x1, y1, x2, y2 = boxes.T
94
+ s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels)
95
+ for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices
96
+ x1b, y1b, x2b, y2b = boxes[i]
97
+ bh, bw = y2b - y1b, x2b - x1b
98
+ yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y
99
+ x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh]
100
+ im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b] # im4[ymin:ymax, xmin:xmax]
101
+ labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0)
102
+
103
+ return im, labels
104
+
105
+
106
+ def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
107
+ # Resize and pad image while meeting stride-multiple constraints
108
+ shape = im.shape[:2] # current shape [height, width]
109
+ if isinstance(new_shape, int):
110
+ new_shape = (new_shape, new_shape)
111
+
112
+ # Scale ratio (new / old)
113
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
114
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
115
+ r = min(r, 1.0)
116
+
117
+ # Compute padding
118
+ ratio = r, r # width, height ratios
119
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
120
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
121
+ if auto: # minimum rectangle
122
+ dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
123
+ elif scaleFill: # stretch
124
+ dw, dh = 0.0, 0.0
125
+ new_unpad = (new_shape[1], new_shape[0])
126
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
127
+
128
+ dw /= 2 # divide padding into 2 sides
129
+ dh /= 2
130
+
131
+ if shape[::-1] != new_unpad: # resize
132
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
133
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
134
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
135
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
136
+ return im, ratio, (dw, dh)
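+
+ # Usage sketch (shapes hypothetical): letterbox a 720x1280 frame onto a 640 canvas,
+ # then map a letterboxed point back to original coordinates with the returned ratio/pad.
+ # im0 = np.zeros((720, 1280, 3), dtype=np.uint8) # original HWC frame
+ # im, ratio, (dw, dh) = letterbox(im0, 640, auto=True) # r=0.5 -> im.shape == (384, 640, 3)
+ # x0, y0 = (320.0 - dw) / ratio[0], (192.0 - dh) / ratio[1] # -> (640.0, 360.0) in im0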
137
+
138
+
139
+ def random_perspective(im,
140
+ targets=(),
141
+ segments=(),
142
+ degrees=10,
143
+ translate=.1,
144
+ scale=.1,
145
+ shear=10,
146
+ perspective=0.0,
147
+ border=(0, 0)):
148
+ # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
149
+ # targets = [cls, xyxy]
150
+
151
+ height = im.shape[0] + border[0] * 2 # shape(h,w,c)
152
+ width = im.shape[1] + border[1] * 2
153
+
154
+ # Center
155
+ C = np.eye(3)
156
+ C[0, 2] = -im.shape[1] / 2 # x translation (pixels)
157
+ C[1, 2] = -im.shape[0] / 2 # y translation (pixels)
158
+
159
+ # Perspective
160
+ P = np.eye(3)
161
+ P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y)
162
+ P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x)
163
+
164
+ # Rotation and Scale
165
+ R = np.eye(3)
166
+ a = random.uniform(-degrees, degrees)
167
+ # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
168
+ s = random.uniform(1 - scale, 1 + scale)
169
+ # s = 2 ** random.uniform(-scale, scale)
170
+ R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
171
+
172
+ # Shear
173
+ S = np.eye(3)
174
+ S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
175
+ S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
176
+
177
+ # Translation
178
+ T = np.eye(3)
179
+ T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels)
180
+ T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels)
181
+
182
+ # Combined rotation matrix
183
+ M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
184
+ if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
185
+ if perspective:
186
+ im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
187
+ else: # affine
188
+ im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
189
+
190
+ # Visualize
191
+ # import matplotlib.pyplot as plt
192
+ # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
193
+ # ax[0].imshow(im[:, :, ::-1]) # base
194
+ # ax[1].imshow(im2[:, :, ::-1]) # warped
195
+
196
+ # Transform label coordinates
197
+ n = len(targets)
198
+ if n:
199
+ use_segments = any(x.any() for x in segments)
200
+ new = np.zeros((n, 4))
201
+ if use_segments: # warp segments
202
+ segments = resample_segments(segments) # upsample
203
+ for i, segment in enumerate(segments):
204
+ xy = np.ones((len(segment), 3))
205
+ xy[:, :2] = segment
206
+ xy = xy @ M.T # transform
207
+ xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine
208
+
209
+ # clip
210
+ new[i] = segment2box(xy, width, height)
211
+
212
+ else: # warp boxes
213
+ xy = np.ones((n * 4, 3))
214
+ xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
215
+ xy = xy @ M.T # transform
216
+ xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
217
+
218
+ # create new boxes
219
+ x = xy[:, [0, 2, 4, 6]]
220
+ y = xy[:, [1, 3, 5, 7]]
221
+ new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
222
+
223
+ # clip
224
+ new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
225
+ new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
226
+
227
+ # filter candidates
228
+ i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
229
+ targets = targets[i]
230
+ targets[:, 1:5] = new[i]
231
+
232
+ return im, targets
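+
+ # Usage sketch (values hypothetical): warp an image together with its xyxy targets;
+ # boxes that shrink too much or leave the canvas are dropped by box_candidates(),
+ # so the returned target array may be shorter than the input.
+ # im = np.full((640, 640, 3), 114, np.uint8)
+ # targets = np.array([[0, 100., 100., 300., 300.]]) # [cls, x1, y1, x2, y2]
+ # im2, targets2 = random_perspective(im, targets, degrees=5, scale=0.1)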
233
+
234
+
235
+ def copy_paste(im, labels, segments, p=0.5):
236
+ # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy)
237
+ n = len(segments)
238
+ if p and n:
239
+ h, w, c = im.shape # height, width, channels
240
+ im_new = np.zeros(im.shape, np.uint8)
241
+
242
+ # calculate ioa first then select indexes randomly
243
+ boxes = np.stack([w - labels[:, 3], labels[:, 2], w - labels[:, 1], labels[:, 4]], axis=-1) # (n, 4)
244
+ ioa = bbox_ioa(boxes, labels[:, 1:5]) # intersection over area
245
+ indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
246
+ n = len(indexes)
247
+ for j in random.sample(list(indexes), k=round(p * n)):
248
+ l, box, s = labels[j], boxes[j], segments[j]
249
+ labels = np.concatenate((labels, [[l[0], *box]]), 0)
250
+ segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1))
251
+ cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (1, 1, 1), cv2.FILLED)
252
+
253
+ result = cv2.flip(im, 1) # augment segments (flip left-right)
254
+ i = cv2.flip(im_new, 1).astype(bool)
255
+ im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug
256
+
257
+ return im, labels, segments
258
+
259
+
260
+ def cutout(im, labels, p=0.5):
261
+ # Applies image cutout augmentation https://arxiv.org/abs/1708.04552
262
+ if random.random() < p:
263
+ h, w = im.shape[:2]
264
+ scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction
265
+ for s in scales:
266
+ mask_h = random.randint(1, int(h * s)) # create random masks
267
+ mask_w = random.randint(1, int(w * s))
268
+
269
+ # box
270
+ xmin = max(0, random.randint(0, w) - mask_w // 2)
271
+ ymin = max(0, random.randint(0, h) - mask_h // 2)
272
+ xmax = min(w, xmin + mask_w)
273
+ ymax = min(h, ymin + mask_h)
274
+
275
+ # apply random color mask
276
+ im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)]
277
+
278
+ # return unobscured labels
279
+ if len(labels) and s > 0.03:
280
+ box = np.array([[xmin, ymin, xmax, ymax]], dtype=np.float32)
281
+ ioa = bbox_ioa(box, xywhn2xyxy(labels[:, 1:5], w, h))[0] # intersection over area
282
+ labels = labels[ioa < 0.60] # remove >60% obscured labels
283
+
284
+ return labels
285
+
286
+
287
+ def mixup(im, labels, im2, labels2):
288
+ # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
289
+ r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
290
+ im = (im * r + im2 * (1 - r)).astype(np.uint8)
291
+ labels = np.concatenate((labels, labels2), 0)
292
+ return im, labels
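+
+ # Usage sketch (arrays hypothetical): Beta(32, 32) concentrates near r=0.5, so the
+ # blend stays close to an even mix while the label sets are simply stacked.
+ # im1, lb1 = np.full((640, 640, 3), 64, np.uint8), np.zeros((2, 5), np.float32)
+ # im2, lb2 = np.full((640, 640, 3), 192, np.uint8), np.zeros((3, 5), np.float32)
+ # im, lb = mixup(im1, lb1, im2, lb2) # im ~= 128 everywhere, lb.shape == (5, 5)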
293
+
294
+
295
+ def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n)
296
+ # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio
297
+ w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
298
+ w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
299
+ ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
300
+ return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
301
+
302
+
303
+ def classify_albumentations(
304
+ augment=True,
305
+ size=224,
306
+ scale=(0.08, 1.0),
307
+ ratio=(0.75, 1.0 / 0.75), # 0.75, 1.33
308
+ hflip=0.5,
309
+ vflip=0.0,
310
+ jitter=0.4,
311
+ mean=IMAGENET_MEAN,
312
+ std=IMAGENET_STD,
313
+ auto_aug=False):
314
+ # YOLOv5 classification Albumentations (optional, only used if package is installed)
315
+ prefix = colorstr('albumentations: ')
316
+ try:
317
+ import albumentations as A
318
+ from albumentations.pytorch import ToTensorV2
319
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
320
+ if augment: # Resize and crop
321
+ T = [A.RandomResizedCrop(height=size, width=size, scale=scale, ratio=ratio)]
322
+ if auto_aug:
323
+ # TODO: implement AugMix, AutoAug & RandAug in albumentation
324
+ LOGGER.info(f'{prefix}auto augmentations are currently not supported')
325
+ else:
326
+ if hflip > 0:
327
+ T += [A.HorizontalFlip(p=hflip)]
328
+ if vflip > 0:
329
+ T += [A.VerticalFlip(p=vflip)]
330
+ if jitter > 0:
331
+ color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, saturation, 0 hue
332
+ T += [A.ColorJitter(*color_jitter, 0)]
333
+ else: # Use fixed crop for eval set (reproducibility)
334
+ T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
335
+ T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
336
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
337
+ return A.Compose(T)
338
+
339
+ except ImportError: # package not installed, skip
340
+ LOGGER.warning(f'{prefix}⚠️ not found, install with `pip install albumentations` (recommended)')
341
+ except Exception as e:
342
+ LOGGER.info(f'{prefix}{e}')
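+
+ # Usage sketch (requires albumentations installed): build the train pipeline and run
+ # it on one HWC uint8 image; the result is a normalized CHW tensor of shape (3, 224, 224).
+ # t = classify_albumentations(augment=True, size=224)
+ # im_t = t(image=np.zeros((480, 640, 3), dtype=np.uint8))['image']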
343
+
344
+
345
+ def classify_transforms(size=224):
346
+ # Transforms to apply if albumentations not installed
347
+ assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
348
+ # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
349
+ return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
350
+
351
+
352
+ class LetterBox:
353
+ # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
354
+ def __init__(self, size=(640, 640), auto=False, stride=32):
355
+ super().__init__()
356
+ self.h, self.w = (size, size) if isinstance(size, int) else size
357
+ self.auto = auto # pass max size integer, automatically solve for short side using stride
358
+ self.stride = stride # used with auto
359
+
360
+ def __call__(self, im): # im = np.array HWC
361
+ imh, imw = im.shape[:2]
362
+ r = min(self.h / imh, self.w / imw) # ratio of new/old
363
+ h, w = round(imh * r), round(imw * r) # resized image
364
+ hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else (self.h, self.w)
365
+ top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
366
+ im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
367
+ im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
368
+ return im_out
369
+
370
+
371
+ class CenterCrop:
372
+ # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
373
+ def __init__(self, size=640):
374
+ super().__init__()
375
+ self.h, self.w = (size, size) if isinstance(size, int) else size
376
+
377
+ def __call__(self, im): # im = np.array HWC
378
+ imh, imw = im.shape[:2]
379
+ m = min(imh, imw) # min dimension
380
+ top, left = (imh - m) // 2, (imw - m) // 2
381
+ return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
382
+
383
+
384
+ class ToTensor:
385
+ # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
386
+ def __init__(self, half=False):
387
+ super().__init__()
388
+ self.half = half
389
+
390
+ def __call__(self, im): # im = np.array HWC in BGR order
391
+ im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
392
+ im = torch.from_numpy(im) # to torch
393
+ im = im.half() if self.half else im.float() # uint8 to fp16/32
394
+ im /= 255.0 # 0-255 to 0.0-1.0
395
+ return im
yolov9/utils/autoanchor.py ADDED
@@ -0,0 +1,164 @@
1
+ import random
2
+
3
+ import numpy as np
4
+ import torch
5
+ import yaml
6
+ from tqdm import tqdm
7
+
8
+ from utils import TryExcept
9
+ from utils.general import LOGGER, TQDM_BAR_FORMAT, colorstr
10
+
11
+ PREFIX = colorstr('AutoAnchor: ')
12
+
13
+
14
+ def check_anchor_order(m):
15
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
16
+ a = m.anchors.prod(-1).mean(-1).view(-1) # mean anchor area per output layer
17
+ da = a[-1] - a[0] # delta a
18
+ ds = m.stride[-1] - m.stride[0] # delta s
19
+ if da and (da.sign() != ds.sign()): # anchor order and stride order disagree
20
+ LOGGER.info(f'{PREFIX}Reversing anchor order')
21
+ m.anchors[:] = m.anchors.flip(0)
22
+
23
+
24
+ @TryExcept(f'{PREFIX}ERROR')
25
+ def check_anchors(dataset, model, thr=4.0, imgsz=640):
26
+ # Check anchor fit to data, recompute if necessary
27
+ m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
28
+ shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
29
+ scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
30
+ wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
31
+
32
+ def metric(k): # compute metric
33
+ r = wh[:, None] / k[None]
34
+ x = torch.min(r, 1 / r).min(2)[0] # ratio metric
35
+ best = x.max(1)[0] # best_x
36
+ aat = (x > 1 / thr).float().sum(1).mean() # anchors above threshold
37
+ bpr = (best > 1 / thr).float().mean() # best possible recall
38
+ return bpr, aat
39
+
40
+ stride = m.stride.to(m.anchors.device).view(-1, 1, 1) # model strides
41
+ anchors = m.anchors.clone() * stride # current anchors
42
+ bpr, aat = metric(anchors.cpu().view(-1, 2))
43
+ s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
44
+ if bpr > 0.98: # good fit, no need to recompute anchors
45
+ LOGGER.info(f'{s}Current anchors are a good fit to dataset ✅')
46
+ else:
47
+ LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
48
+ na = m.anchors.numel() // 2 # number of anchors
49
+ anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
50
+ new_bpr = metric(anchors)[0]
51
+ if new_bpr > bpr: # replace anchors
52
+ anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
53
+ m.anchors[:] = anchors.clone().view_as(m.anchors)
54
+ check_anchor_order(m) # must be in pixel-space (not grid-space)
55
+ m.anchors /= stride
56
+ s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
57
+ else:
58
+ s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
59
+ LOGGER.info(s)
60
+
61
+
62
+ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
63
+ """ Creates kmeans-evolved anchors from training dataset
64
+
65
+ Arguments:
66
+ dataset: path to data.yaml, or a loaded dataset
67
+ n: number of anchors
68
+ img_size: image size used for training
69
+ thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
70
+ gen: generations to evolve anchors using genetic algorithm
71
+ verbose: print all results
72
+
73
+ Return:
74
+ k: kmeans evolved anchors
75
+
76
+ Usage:
77
+ from utils.autoanchor import *; _ = kmean_anchors()
78
+ """
79
+ from scipy.cluster.vq import kmeans
80
+
81
+ npr = np.random
82
+ thr = 1 / thr
83
+
84
+ def metric(k, wh): # compute metrics
85
+ r = wh[:, None] / k[None]
86
+ x = torch.min(r, 1 / r).min(2)[0] # ratio metric
87
+ # x = wh_iou(wh, torch.tensor(k)) # iou metric
88
+ return x, x.max(1)[0] # x, best_x
89
+
90
+ def anchor_fitness(k): # mutation fitness
91
+ _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
92
+ return (best * (best > thr).float()).mean() # fitness
93
+
94
+ def print_results(k, verbose=True):
95
+ k = k[np.argsort(k.prod(1))] # sort small to large
96
+ x, best = metric(k, wh0)
97
+ bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
98
+ s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
99
+ f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
100
+ f'past_thr={x[x > thr].mean():.3f}-mean: '
101
+ for x in k:
102
+ s += '%i,%i, ' % (round(x[0]), round(x[1]))
103
+ if verbose:
104
+ LOGGER.info(s[:-2])
105
+ return k
106
+
107
+ if isinstance(dataset, str): # *.yaml file
108
+ with open(dataset, errors='ignore') as f:
109
+ data_dict = yaml.safe_load(f) # model dict
110
+ from utils.dataloaders import LoadImagesAndLabels
111
+ dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
112
+
113
+ # Get label wh
114
+ shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
115
+ wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
116
+
117
+ # Filter
118
+ i = (wh0 < 3.0).any(1).sum()
119
+ if i:
120
+ LOGGER.info(f'{PREFIX}WARNING ⚠️ Extremely small objects found: {i} of {len(wh0)} labels are <3 pixels in size')
121
+ wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # keep labels with any side >= 2 pixels
122
+ # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
123
+
124
+ # Kmeans init
125
+ try:
126
+ LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
127
+ assert n <= len(wh) # apply overdetermined constraint
128
+ s = wh.std(0) # sigmas for whitening
129
+ k = kmeans(wh / s, n, iter=30)[0] * s # points
130
+ assert n == len(k) # kmeans may return fewer points than requested if wh is insufficient or too similar
131
+ except Exception:
132
+ LOGGER.warning(f'{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init')
133
+ k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size # random init
134
+ wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
135
+ k = print_results(k, verbose=False)
136
+
137
+ # Plot
138
+ # k, d = [None] * 20, [None] * 20
139
+ # for i in tqdm(range(1, 21)):
140
+ # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
141
+ # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
142
+ # ax = ax.ravel()
143
+ # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
144
+ # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
145
+ # ax[0].hist(wh[wh[:, 0]<100, 0],400)
146
+ # ax[1].hist(wh[wh[:, 1]<100, 1],400)
147
+ # fig.savefig('wh.png', dpi=200)
148
+
149
+ # Evolve
150
+ f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, anchor shape, mutation prob, sigma
151
+ pbar = tqdm(range(gen), bar_format=TQDM_BAR_FORMAT) # progress bar
152
+ for _ in pbar:
153
+ v = np.ones(sh)
154
+ while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
155
+ v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
156
+ kg = (k.copy() * v).clip(min=2.0)
157
+ fg = anchor_fitness(kg)
158
+ if fg > f:
159
+ f, k = fg, kg.copy()
160
+ pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
161
+ if verbose:
162
+ print_results(k, verbose)
163
+
164
+ return print_results(k).astype(np.float32)
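+
+ # Worked example of the ratio metric (numbers hypothetical): a 40x80 label against a
+ # 46x72 anchor gives r = (0.87, 1.11); min(r, 1/r) per dimension is (0.87, 0.90), so
+ # the pair scores 0.87 and passes the default threshold thr = 1/4.0.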
yolov9/utils/autobatch.py ADDED
@@ -0,0 +1,67 @@
1
+ from copy import deepcopy
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from utils.general import LOGGER, colorstr
7
+ from utils.torch_utils import profile
8
+
9
+
10
+ def check_train_batch_size(model, imgsz=640, amp=True):
11
+ # Check YOLOv5 training batch size
12
+ with torch.cuda.amp.autocast(amp):
13
+ return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
14
+
15
+
16
+ def autobatch(model, imgsz=640, fraction=0.8, batch_size=16):
17
+ # Automatically estimate best YOLOv5 batch size to use `fraction` of available CUDA memory
18
+ # Usage:
19
+ # import torch
20
+ # from utils.autobatch import autobatch
21
+ # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False)
22
+ # print(autobatch(model))
23
+
24
+ # Check device
25
+ prefix = colorstr('AutoBatch: ')
26
+ LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}')
27
+ device = next(model.parameters()).device # get model device
28
+ if device.type == 'cpu':
29
+ LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}')
30
+ return batch_size
31
+ if torch.backends.cudnn.benchmark:
32
+ LOGGER.info(f'{prefix} ⚠️ Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}')
33
+ return batch_size
34
+
35
+ # Inspect CUDA memory
36
+ gb = 1 << 30 # bytes to GiB (1024 ** 3)
37
+ d = str(device).upper() # 'CUDA:0'
38
+ properties = torch.cuda.get_device_properties(device) # device properties
39
+ t = properties.total_memory / gb # GiB total
40
+ r = torch.cuda.memory_reserved(device) / gb # GiB reserved
41
+ a = torch.cuda.memory_allocated(device) / gb # GiB allocated
42
+ f = t - (r + a) # GiB free
43
+ LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free')
44
+
45
+ # Profile batch sizes
46
+ batch_sizes = [1, 2, 4, 8, 16]
47
+ try:
48
+ img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
49
+ results = profile(img, model, n=3, device=device)
50
+ except Exception as e:
51
+ LOGGER.warning(f'{prefix}{e}')
+ return batch_size # profiling failed; fall back to default (otherwise `results` below is undefined)
52
+
53
+ # Fit a solution
54
+ y = [x[2] for x in results if x] # memory [2]
55
+ p = np.polyfit(batch_sizes[:len(y)], y, deg=1) # first degree polynomial fit
56
+ b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size)
57
+ if None in results: # some sizes failed
58
+ i = results.index(None) # first fail index
59
+ if b >= batch_sizes[i]: # y intercept above failure point
60
+ b = batch_sizes[max(i - 1, 0)] # select prior safe point
61
+ if b < 1 or b > 1024: # b outside of safe range
62
+ b = batch_size
63
+ LOGGER.warning(f'{prefix}WARNING ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
64
+
65
+ fraction = (np.polyval(p, b) + r + a) / t # actual fraction predicted
66
+ LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
67
+ return b
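+
+ # Worked example of the linear fit (numbers hypothetical): if profiling measured
+ # y = [0.5, 0.9, 1.7, 3.3, 6.5] GiB at batch sizes [1, 2, 4, 8, 16], polyfit gives
+ # slope p[0] ~= 0.4 GiB/image and intercept p[1] ~= 0.1 GiB, so with f = 10 GiB free
+ # and fraction = 0.8, b = int((10 * 0.8 - 0.1) / 0.4) = 19 images per batch.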
yolov9/utils/callbacks.py ADDED
@@ -0,0 +1,71 @@
1
+ import threading
2
+
3
+
4
+ class Callbacks:
5
+ """"
6
+ Handles all registered callbacks for YOLOv5 Hooks
7
+ """
8
+
9
+ def __init__(self):
10
+ # Define the available callbacks
11
+ self._callbacks = {
12
+ 'on_pretrain_routine_start': [],
13
+ 'on_pretrain_routine_end': [],
14
+ 'on_train_start': [],
15
+ 'on_train_epoch_start': [],
16
+ 'on_train_batch_start': [],
17
+ 'optimizer_step': [],
18
+ 'on_before_zero_grad': [],
19
+ 'on_train_batch_end': [],
20
+ 'on_train_epoch_end': [],
21
+ 'on_val_start': [],
22
+ 'on_val_batch_start': [],
23
+ 'on_val_image_end': [],
24
+ 'on_val_batch_end': [],
25
+ 'on_val_end': [],
26
+ 'on_fit_epoch_end': [], # fit = train + val
27
+ 'on_model_save': [],
28
+ 'on_train_end': [],
29
+ 'on_params_update': [],
30
+ 'teardown': [],}
31
+ self.stop_training = False # set True to interrupt training
32
+
33
+ def register_action(self, hook, name='', callback=None):
34
+ """
35
+ Register a new action to a callback hook
36
+
37
+ Args:
38
+ hook: The callback hook name to register the action to
39
+ name: The name of the action for later reference
40
+ callback: The callback to fire
41
+ """
42
+ assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
43
+ assert callable(callback), f"callback '{callback}' is not callable"
44
+ self._callbacks[hook].append({'name': name, 'callback': callback})
45
+
46
+ def get_registered_actions(self, hook=None):
47
+ """"
48
+ Returns all the registered actions by callback hook
49
+
50
+ Args:
51
+ hook: The name of the hook to check, defaults to all
52
+ """
53
+ return self._callbacks[hook] if hook else self._callbacks
54
+
55
+ def run(self, hook, *args, thread=False, **kwargs):
56
+ """
57
+ Loop through the registered actions and fire all callbacks on the main thread or, with thread=True, in daemon threads
58
+
59
+ Args:
60
+ hook: The name of the hook whose callbacks to fire
61
+ args: Arguments to receive from YOLOv5
62
+ thread: (boolean) Run callbacks in daemon thread
63
+ kwargs: Keyword Arguments to receive from YOLOv5
64
+ """
65
+
66
+ assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
67
+ for logger in self._callbacks[hook]:
68
+ if thread:
69
+ threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
70
+ else:
71
+ logger['callback'](*args, **kwargs)
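+
+ # Usage sketch (handler hypothetical):
+ # callbacks = Callbacks()
+ # callbacks.register_action('on_train_end', name='log_done',
+ # callback=lambda *a, **k: print('training finished'))
+ # callbacks.run('on_train_end') # fires every action registered on that hook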
yolov9/utils/dataloaders.py ADDED
@@ -0,0 +1,1217 @@
1
+ import contextlib
2
+ import glob
3
+ import hashlib
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ import shutil
9
+ import time
10
+ from itertools import repeat
11
+ from multiprocessing.pool import Pool, ThreadPool
12
+ from pathlib import Path
13
+ from threading import Thread
14
+ from urllib.parse import urlparse
15
+
16
+ import numpy as np
17
+ import psutil
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import torchvision
21
+ import yaml
22
+ from PIL import ExifTags, Image, ImageOps
23
+ from torch.utils.data import DataLoader, Dataset, dataloader, distributed
24
+ from tqdm import tqdm
25
+
26
+ from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
27
+ letterbox, mixup, random_perspective)
28
+ from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, check_dataset, check_requirements,
29
+ check_yaml, clean_str, cv2, is_colab, is_kaggle, segments2boxes, unzip_file, xyn2xy,
30
+ xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
31
+ from utils.torch_utils import torch_distributed_zero_first
32
+
33
+ # Parameters
34
+ HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
35
+ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
36
+ VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
37
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
38
+ RANK = int(os.getenv('RANK', -1))
39
+ PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
40
+
41
+ # Get orientation exif tag
42
+ for orientation in ExifTags.TAGS.keys():
43
+ if ExifTags.TAGS[orientation] == 'Orientation':
44
+ break
45
+
46
+
47
+ def get_hash(paths):
48
+ # Returns a single hash value of a list of paths (files or dirs)
49
+ size = sum(os.path.getsize(p) for p in paths if os.path.exists(p)) # sizes
50
+ h = hashlib.md5(str(size).encode()) # hash sizes
51
+ h.update(''.join(paths).encode()) # hash paths
52
+ return h.hexdigest() # return hash
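+
+ # Sketch (paths hypothetical): this hash keys the label cache, so it changes whenever
+ # any file size or any path in the list changes, invalidating stale *.cache files.
+ # get_hash(['data/images/1.jpg', 'data/labels/1.txt']) # -> 32-char md5 hex digest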
53
+
54
+
55
+ def exif_size(img):
56
+ # Returns exif-corrected PIL size
57
+ s = img.size # (width, height)
58
+ with contextlib.suppress(Exception):
59
+ rotation = dict(img._getexif().items())[orientation]
60
+ if rotation in [6, 8]: # rotation 270 or 90
61
+ s = (s[1], s[0])
62
+ return s
63
+
64
+
65
+ def exif_transpose(image):
66
+ """
67
+ Transpose a PIL image accordingly if it has an EXIF Orientation tag.
68
+ Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
69
+
70
+ :param image: The image to transpose.
71
+ :return: An image.
72
+ """
73
+ exif = image.getexif()
74
+ orientation = exif.get(0x0112, 1) # default 1
75
+ if orientation > 1:
76
+ method = {
77
+ 2: Image.FLIP_LEFT_RIGHT,
78
+ 3: Image.ROTATE_180,
79
+ 4: Image.FLIP_TOP_BOTTOM,
80
+ 5: Image.TRANSPOSE,
81
+ 6: Image.ROTATE_270,
82
+ 7: Image.TRANSVERSE,
83
+ 8: Image.ROTATE_90}.get(orientation)
84
+ if method is not None:
85
+ image = image.transpose(method)
86
+ del exif[0x0112]
87
+ image.info["exif"] = exif.tobytes()
88
+ return image
89
+
90
+
91
+ def seed_worker(worker_id):
92
+ # Set dataloader worker seed https://pytorch.org/docs/stable/notes/randomness.html#dataloader
93
+ worker_seed = torch.initial_seed() % 2 ** 32
94
+ np.random.seed(worker_seed)
95
+ random.seed(worker_seed)
96
+
97
+
98
+ def create_dataloader(path,
99
+ imgsz,
100
+ batch_size,
101
+ stride,
102
+ single_cls=False,
103
+ hyp=None,
104
+ augment=False,
105
+ cache=False,
106
+ pad=0.0,
107
+ rect=False,
108
+ rank=-1,
109
+ workers=8,
110
+ image_weights=False,
111
+ close_mosaic=False,
112
+ quad=False,
113
+ min_items=0,
114
+ prefix='',
115
+ shuffle=False):
116
+ if rect and shuffle:
117
+ LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
118
+ shuffle = False
119
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
120
+ dataset = LoadImagesAndLabels(
121
+ path,
122
+ imgsz,
123
+ batch_size,
124
+ augment=augment, # augmentation
125
+ hyp=hyp, # hyperparameters
126
+ rect=rect, # rectangular batches
127
+ cache_images=cache,
128
+ single_cls=single_cls,
129
+ stride=int(stride),
130
+ pad=pad,
131
+ image_weights=image_weights,
132
+ min_items=min_items,
133
+ prefix=prefix)
134
+
135
+ batch_size = min(batch_size, len(dataset))
136
+ nd = torch.cuda.device_count() # number of CUDA devices
137
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
138
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
139
+ #loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
140
+ loader = DataLoader if image_weights or close_mosaic else InfiniteDataLoader
141
+ generator = torch.Generator()
142
+ generator.manual_seed(6148914691236517205 + RANK)
143
+ return loader(dataset,
144
+ batch_size=batch_size,
145
+ shuffle=shuffle and sampler is None,
146
+ num_workers=nw,
147
+ sampler=sampler,
148
+ pin_memory=PIN_MEMORY,
149
+ collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
150
+ worker_init_fn=seed_worker,
151
+ generator=generator), dataset
152
+
153
+
154
+ class InfiniteDataLoader(dataloader.DataLoader):
155
+ """ Dataloader that reuses workers
156
+
157
+ Uses same syntax as vanilla DataLoader
158
+ """
159
+
160
+ def __init__(self, *args, **kwargs):
161
+ super().__init__(*args, **kwargs)
162
+ object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
163
+ self.iterator = super().__iter__()
164
+
165
+ def __len__(self):
166
+ return len(self.batch_sampler.sampler)
167
+
168
+ def __iter__(self):
169
+ for _ in range(len(self)):
170
+ yield next(self.iterator)
171
+
172
+
173
+ class _RepeatSampler:
174
+ """ Sampler that repeats forever
175
+
176
+ Args:
177
+ sampler (Sampler)
178
+ """
179
+
180
+ def __init__(self, sampler):
181
+ self.sampler = sampler
182
+
183
+ def __iter__(self):
184
+ while True:
185
+ yield from iter(self.sampler)
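+
+ # Sketch of the reuse pattern (dataset hypothetical): because _RepeatSampler never
+ # raises StopIteration, the single iterator built in __init__ keeps its worker
+ # processes alive across epochs instead of respawning them every epoch.
+ # loader = InfiniteDataLoader(my_dataset, batch_size=16, num_workers=4)
+ # for epoch in range(300):
+ #     for batch in loader: # __iter__ yields len(loader) batches per pass
+ #         ...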
186
+
187
+
188
+ class LoadScreenshots:
189
+ # YOLOv5 screenshot dataloader, i.e. `python detect.py --source "screen 0 100 100 512 256"`
190
+ def __init__(self, source, img_size=640, stride=32, auto=True, transforms=None):
191
+ # source = [screen_number left top width height] (pixels)
192
+ check_requirements('mss')
193
+ import mss
194
+
195
+ source, *params = source.split()
196
+ self.screen, left, top, width, height = 0, None, None, None, None # default to full screen 0
197
+ if len(params) == 1:
198
+ self.screen = int(params[0])
199
+ elif len(params) == 4:
200
+ left, top, width, height = (int(x) for x in params)
201
+ elif len(params) == 5:
202
+ self.screen, left, top, width, height = (int(x) for x in params)
203
+ self.img_size = img_size
204
+ self.stride = stride
205
+ self.transforms = transforms
206
+ self.auto = auto
207
+ self.mode = 'stream'
208
+ self.frame = 0
209
+ self.sct = mss.mss()
210
+
211
+ # Parse monitor shape
212
+ monitor = self.sct.monitors[self.screen]
213
+ self.top = monitor["top"] if top is None else (monitor["top"] + top)
214
+ self.left = monitor["left"] if left is None else (monitor["left"] + left)
215
+ self.width = width or monitor["width"]
216
+ self.height = height or monitor["height"]
217
+ self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}
218
+
219
+ def __iter__(self):
220
+ return self
221
+
222
+ def __next__(self):
223
+ # mss screen capture: get raw pixels from the screen as np array
224
+ im0 = np.array(self.sct.grab(self.monitor))[:, :, :3] # [:, :, :3] BGRA to BGR
225
+ s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "
226
+
227
+ if self.transforms:
228
+ im = self.transforms(im0) # transforms
229
+ else:
230
+ im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
231
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
232
+ im = np.ascontiguousarray(im) # contiguous
233
+ self.frame += 1
234
+ return str(self.screen), im, im0, None, s # screen, img, original img, im0s, s
235
+
236
+
237
+ class LoadImages:
238
+ # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
239
+ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
240
+ files = []
241
+ for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
242
+ p = str(Path(p).resolve())
243
+ if '*' in p:
244
+ files.extend(sorted(glob.glob(p, recursive=True))) # glob
245
+ elif os.path.isdir(p):
246
+ files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
247
+ elif os.path.isfile(p):
248
+ files.append(p) # files
249
+ else:
250
+ raise FileNotFoundError(f'{p} does not exist')
251
+
252
+ images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
253
+ videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
254
+ ni, nv = len(images), len(videos)
255
+
256
+ self.img_size = img_size
257
+ self.stride = stride
258
+ self.files = images + videos
259
+ self.nf = ni + nv # number of files
260
+ self.video_flag = [False] * ni + [True] * nv
261
+ self.mode = 'image'
262
+ self.auto = auto
263
+ self.transforms = transforms # optional
264
+ self.vid_stride = vid_stride # video frame-rate stride
265
+ if any(videos):
266
+ self._new_video(videos[0]) # new video
267
+ else:
268
+ self.cap = None
269
+ assert self.nf > 0, f'No images or videos found in {p}. ' \
270
+ f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
271
+
272
+ def __iter__(self):
273
+ self.count = 0
274
+ return self
275
+
276
+ def __next__(self):
277
+ if self.count == self.nf:
278
+ raise StopIteration
279
+ path = self.files[self.count]
280
+
281
+ if self.video_flag[self.count]:
282
+ # Read video
283
+ self.mode = 'video'
284
+ for _ in range(self.vid_stride):
285
+ self.cap.grab()
286
+ ret_val, im0 = self.cap.retrieve()
287
+ while not ret_val:
288
+ self.count += 1
289
+ self.cap.release()
290
+ if self.count == self.nf: # last video
291
+ raise StopIteration
292
+ path = self.files[self.count]
293
+ self._new_video(path)
294
+ ret_val, im0 = self.cap.read()
295
+
296
+ self.frame += 1
297
+ # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
298
+ s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
299
+
300
+ else:
301
+ # Read image
302
+ self.count += 1
303
+ im0 = cv2.imread(path) # BGR
304
+ assert im0 is not None, f'Image Not Found {path}'
305
+ s = f'image {self.count}/{self.nf} {path}: '
306
+
307
+ if self.transforms:
308
+ im = self.transforms(im0) # transforms
309
+ else:
310
+ im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
311
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
312
+ im = np.ascontiguousarray(im) # contiguous
313
+
314
+ return path, im, im0, self.cap, s
315
+
316
+ def _new_video(self, path):
317
+ # Create a new video capture object
318
+ self.frame = 0
319
+ self.cap = cv2.VideoCapture(path)
320
+ self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
321
+ self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
322
+ # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
323
+
324
+ def _cv2_rotate(self, im):
325
+ # Rotate a cv2 video manually
326
+ if self.orientation == 0:
327
+ return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
328
+ elif self.orientation == 180:
329
+ return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
330
+ elif self.orientation == 90:
331
+ return cv2.rotate(im, cv2.ROTATE_180)
332
+ return im
333
+
334
+ def __len__(self):
335
+ return self.nf # number of files
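+
+ # Usage sketch (source path hypothetical):
+ # for path, im, im0, cap, s in LoadImages('data/images', img_size=640):
+ #     pass # im is the letterboxed CHW RGB array, im0 the original BGR frame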
336
+
337
+
338
+ class LoadStreams:
339
+ # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
340
+ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
341
+ torch.backends.cudnn.benchmark = True # faster for fixed-size inference
342
+ self.mode = 'stream'
343
+ self.img_size = img_size
344
+ self.stride = stride
345
+ self.vid_stride = vid_stride # video frame-rate stride
346
+ sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
347
+ n = len(sources)
348
+ self.sources = [clean_str(x) for x in sources] # clean source names for later
349
+ self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
350
+ for i, s in enumerate(sources): # index, source
351
+ # Start thread to read frames from video stream
352
+ st = f'{i + 1}/{n}: {s}... '
353
+ if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
354
+ # YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/Zgi9g1ksQHc'
355
+ check_requirements(('pafy', 'youtube_dl==2020.12.2'))
356
+ import pafy
357
+ s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
358
+ s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
359
+ if s == 0:
360
+ assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
361
+ assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
362
+ cap = cv2.VideoCapture(s)
363
+ assert cap.isOpened(), f'{st}Failed to open {s}'
364
+ w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
365
+ h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
366
+ fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
367
+ self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
368
+ self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
369
+
370
+ _, self.imgs[i] = cap.read() # guarantee first frame
371
+ self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
372
+ LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
373
+ self.threads[i].start()
374
+ LOGGER.info('') # newline
375
+
376
+ # check for common shapes
377
+ s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
378
+ self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
379
+ self.auto = auto and self.rect
380
+ self.transforms = transforms # optional
381
+ if not self.rect:
382
+ LOGGER.warning('WARNING ⚠️ Stream shapes differ. For optimal performance supply similarly-shaped streams.')
383
+
384
+ def update(self, i, cap, stream):
385
+ # Read stream `i` frames in daemon thread
386
+ n, f = 0, self.frames[i] # frame number, frame array
387
+ while cap.isOpened() and n < f:
388
+ n += 1
389
+ cap.grab() # .read() = .grab() followed by .retrieve()
390
+ if n % self.vid_stride == 0:
391
+ success, im = cap.retrieve()
392
+ if success:
393
+ self.imgs[i] = im
394
+ else:
395
+ LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
396
+ self.imgs[i] = np.zeros_like(self.imgs[i])
397
+ cap.open(stream) # re-open stream if signal was lost
398
+ time.sleep(0.0) # wait time
399
+
400
+ def __iter__(self):
401
+ self.count = -1
402
+ return self
403
+
404
+ def __next__(self):
405
+ self.count += 1
406
+ if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
407
+ cv2.destroyAllWindows()
408
+ raise StopIteration
409
+
410
+ im0 = self.imgs.copy()
411
+ if self.transforms:
412
+ im = np.stack([self.transforms(x) for x in im0]) # transforms
413
+ else:
414
+ im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
415
+ im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
416
+ im = np.ascontiguousarray(im) # contiguous
417
+
418
+ return self.sources, im, im0, None, ''
419
+
420
+ def __len__(self):
421
+ return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
422
+
423
+
424
+ def img2label_paths(img_paths):
425
+ # Define label paths as a function of image paths
426
+ sa, sb = f'{os.sep}images{os.sep}', f'{os.sep}labels{os.sep}' # /images/, /labels/ substrings
427
+ return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]
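+
+ # Worked example (paths hypothetical, POSIX separators):
+ # img2label_paths(['/data/coco/images/train2017/000000000009.jpg'])
+ # -> ['/data/coco/labels/train2017/000000000009.txt']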
428
+
429
+
430
+ class LoadImagesAndLabels(Dataset):
431
+ # YOLOv5 train_loader/val_loader, loads images and labels for training and validation
432
+ cache_version = 0.6 # dataset labels *.cache version
433
+ rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
434
+
435
+ def __init__(self,
436
+ path,
437
+ img_size=640,
438
+ batch_size=16,
439
+ augment=False,
440
+ hyp=None,
441
+ rect=False,
442
+ image_weights=False,
443
+ cache_images=False,
444
+ single_cls=False,
445
+ stride=32,
446
+ pad=0.0,
447
+ min_items=0,
448
+ prefix=''):
449
+ self.img_size = img_size
450
+ self.augment = augment
451
+ self.hyp = hyp
452
+ self.image_weights = image_weights
453
+ self.rect = False if image_weights else rect
454
+ self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
455
+ self.mosaic_border = [-img_size // 2, -img_size // 2]
456
+ self.stride = stride
457
+ self.path = path
458
+ self.albumentations = Albumentations(size=img_size) if augment else None
459
+
460
+ try:
461
+ f = [] # image files
462
+ for p in path if isinstance(path, list) else [path]:
463
+ p = Path(p) # os-agnostic
464
+ if p.is_dir(): # dir
465
+ f += glob.glob(str(p / '**' / '*.*'), recursive=True)
466
+ # f = list(p.rglob('*.*')) # pathlib
467
+ elif p.is_file(): # file
468
+ with open(p) as t:
469
+ t = t.read().strip().splitlines()
470
+ parent = str(p.parent) + os.sep
471
+ f += [x.replace('./', parent, 1) if x.startswith('./') else x for x in t] # to global path
472
+ # f += [p.parent / x.lstrip(os.sep) for x in t] # to global path (pathlib)
473
+ else:
474
+ raise FileNotFoundError(f'{prefix}{p} does not exist')
475
+ self.im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
476
+ # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
477
+ assert self.im_files, f'{prefix}No images found'
478
+ except Exception as e:
479
+ raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}') from e
480
+
481
+ # Check cache
482
+ self.label_files = img2label_paths(self.im_files) # labels
483
+ cache_path = (p if p.is_file() else Path(self.label_files[0]).parent).with_suffix('.cache')
484
+ try:
485
+ cache, exists = np.load(cache_path, allow_pickle=True).item(), True # load dict
486
+ assert cache['version'] == self.cache_version # matches current version
487
+ assert cache['hash'] == get_hash(self.label_files + self.im_files) # identical hash
488
+ except Exception:
489
+ cache, exists = self.cache_labels(cache_path, prefix), False # run cache ops
490
+
491
+ # Display cache
492
+ nf, nm, ne, nc, n = cache.pop('results') # found, missing, empty, corrupt, total
493
+ if exists and LOCAL_RANK in {-1, 0}:
494
+ d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
495
+ tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT) # display cache results
496
+ if cache['msgs']:
497
+ LOGGER.info('\n'.join(cache['msgs'])) # display warnings
498
+ assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'
499
+
500
+ # Read cache
501
+ [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
502
+ labels, shapes, self.segments = zip(*cache.values())
503
+ nl = len(np.concatenate(labels, 0)) # number of labels
504
+ assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
505
+ self.labels = list(labels)
506
+ self.shapes = np.array(shapes)
507
+ self.im_files = list(cache.keys()) # update
508
+ self.label_files = img2label_paths(cache.keys()) # update
509
+
510
+ # Filter images
511
+ if min_items:
512
+ include = np.array([len(x) >= min_items for x in self.labels]).nonzero()[0].astype(int)
513
+ LOGGER.info(f'{prefix}{n - len(include)}/{n} images filtered from dataset')
514
+ self.im_files = [self.im_files[i] for i in include]
515
+ self.label_files = [self.label_files[i] for i in include]
516
+ self.labels = [self.labels[i] for i in include]
517
+ self.segments = [self.segments[i] for i in include]
518
+ self.shapes = self.shapes[include] # wh
519
+
520
+ # Create indices
521
+ n = len(self.shapes) # number of images
522
+ bi = np.floor(np.arange(n) / batch_size).astype(int) # batch index
523
+ nb = bi[-1] + 1 # number of batches
524
+ self.batch = bi # batch index of image
525
+ self.n = n
526
+ self.indices = range(n)
527
+
528
+ # Update labels
529
+ include_class = [] # filter labels to include only these classes (optional)
530
+ include_class_array = np.array(include_class).reshape(1, -1)
531
+ for i, (label, segment) in enumerate(zip(self.labels, self.segments)):
532
+ if include_class:
533
+ j = (label[:, 0:1] == include_class_array).any(1)
534
+ self.labels[i] = label[j]
535
+ if segment:
536
+ self.segments[i] = segment[j]
537
+ if single_cls: # single-class training, merge all classes into 0
538
+ self.labels[i][:, 0] = 0
539
+
540
+ # Rectangular Training
541
+ if self.rect:
542
+ # Sort by aspect ratio
543
+ s = self.shapes # wh
544
+ ar = s[:, 1] / s[:, 0] # aspect ratio
545
+ irect = ar.argsort()
546
+ self.im_files = [self.im_files[i] for i in irect]
547
+ self.label_files = [self.label_files[i] for i in irect]
548
+ self.labels = [self.labels[i] for i in irect]
549
+ self.segments = [self.segments[i] for i in irect]
550
+ self.shapes = s[irect] # wh
551
+ ar = ar[irect]
552
+
553
+ # Set training image shapes
554
+ shapes = [[1, 1]] * nb
555
+ for i in range(nb):
556
+ ari = ar[bi == i]
557
+ mini, maxi = ari.min(), ari.max()
558
+ if maxi < 1:
559
+ shapes[i] = [maxi, 1]
560
+ elif mini > 1:
561
+ shapes[i] = [1, 1 / mini]
562
+
563
+ self.batch_shapes = np.ceil(np.array(shapes) * img_size / stride + pad).astype(int) * stride
564
+
565
+ # Cache images into RAM/disk for faster training
566
+ if cache_images == 'ram' and not self.check_cache_ram(prefix=prefix):
567
+ cache_images = False
568
+ self.ims = [None] * n
569
+ self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
570
+ if cache_images:
571
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
572
+ self.im_hw0, self.im_hw = [None] * n, [None] * n
573
+ fcn = self.cache_images_to_disk if cache_images == 'disk' else self.load_image
574
+ results = ThreadPool(NUM_THREADS).imap(fcn, range(n))
575
+ pbar = tqdm(enumerate(results), total=n, bar_format=TQDM_BAR_FORMAT, disable=LOCAL_RANK > 0)
576
+ for i, x in pbar:
577
+ if cache_images == 'disk':
578
+ b += self.npy_files[i].stat().st_size
579
+ else: # 'ram'
580
+ self.ims[i], self.im_hw0[i], self.im_hw[i] = x # im, hw_orig, hw_resized = load_image(self, i)
581
+ b += self.ims[i].nbytes
582
+ pbar.desc = f'{prefix}Caching images ({b / gb:.1f}GB {cache_images})'
583
+ pbar.close()
584
+
585
+ def check_cache_ram(self, safety_margin=0.1, prefix=''):
586
+ # Check image caching requirements vs available memory
587
+ b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
588
+ n = min(self.n, 30) # extrapolate from 30 random images
589
+ for _ in range(n):
590
+ im = cv2.imread(random.choice(self.im_files)) # sample image
591
+ ratio = self.img_size / max(im.shape[0], im.shape[1]) # ratio to the longest side, max(h, w)
592
+ b += im.nbytes * ratio ** 2
593
+ mem_required = b * self.n / n # GB required to cache dataset into RAM
594
+ mem = psutil.virtual_memory()
595
+ cache = mem_required * (1 + safety_margin) < mem.available # to cache or not to cache, that is the question
596
+ if not cache:
597
+ LOGGER.info(f"{prefix}{mem_required / gb:.1f}GB RAM required, "
598
+ f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, "
599
+ f"{'caching images ✅' if cache else 'not caching images ⚠️'}")
600
+ return cache
601
+
602
+ def cache_labels(self, path=Path('./labels.cache'), prefix=''):
603
+ # Cache dataset labels, check images and read shapes
604
+ x = {} # dict
605
+ nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
606
+ desc = f"{prefix}Scanning {path.parent / path.stem}..."
607
+ with Pool(NUM_THREADS) as pool:
608
+ pbar = tqdm(pool.imap(verify_image_label, zip(self.im_files, self.label_files, repeat(prefix))),
609
+ desc=desc,
610
+ total=len(self.im_files),
611
+ bar_format=TQDM_BAR_FORMAT)
612
+ for im_file, lb, shape, segments, nm_f, nf_f, ne_f, nc_f, msg in pbar:
613
+ nm += nm_f
614
+ nf += nf_f
615
+ ne += ne_f
616
+ nc += nc_f
617
+ if im_file:
618
+ x[im_file] = [lb, shape, segments]
619
+ if msg:
620
+ msgs.append(msg)
621
+ pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
622
+
623
+ pbar.close()
624
+ if msgs:
625
+ LOGGER.info('\n'.join(msgs))
626
+ if nf == 0:
627
+ LOGGER.warning(f'{prefix}WARNING ⚠️ No labels found in {path}. {HELP_URL}')
628
+ x['hash'] = get_hash(self.label_files + self.im_files)
629
+ x['results'] = nf, nm, ne, nc, len(self.im_files)
630
+ x['msgs'] = msgs # warnings
631
+ x['version'] = self.cache_version # cache version
632
+ try:
633
+ np.save(path, x) # save cache for next time
634
+ path.with_suffix('.cache.npy').rename(path) # remove .npy suffix
635
+ LOGGER.info(f'{prefix}New cache created: {path}')
636
+ except Exception as e:
637
+ LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable: {e}') # not writeable
638
+ return x
639
+
640
+ def __len__(self):
641
+ return len(self.im_files)
642
+
643
+ # def __iter__(self):
644
+ # self.count = -1
645
+ # print('ran dataset iter')
646
+ # #self.shuffled_vector = np.random.permutation(self.nF) if self.augment else np.arange(self.nF)
647
+ # return self
648
+
649
+    def __getitem__(self, index):
+        index = self.indices[index]  # linear, shuffled, or image_weights
+
+        hyp = self.hyp
+        mosaic = self.mosaic and random.random() < hyp['mosaic']
+        if mosaic:
+            # Load mosaic
+            img, labels = self.load_mosaic(index)
+            shapes = None
+
+            # MixUp augmentation
+            if random.random() < hyp['mixup']:
+                img, labels = mixup(img, labels, *self.load_mosaic(random.randint(0, self.n - 1)))
+
+        else:
+            # Load image
+            img, (h0, w0), (h, w) = self.load_image(index)
+
+            # Letterbox
+            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
+            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
+            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling
+
+            labels = self.labels[index].copy()
+            if labels.size:  # normalized xywh to pixel xyxy format
+                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
+
+            if self.augment:
+                img, labels = random_perspective(img,
+                                                 labels,
+                                                 degrees=hyp['degrees'],
+                                                 translate=hyp['translate'],
+                                                 scale=hyp['scale'],
+                                                 shear=hyp['shear'],
+                                                 perspective=hyp['perspective'])
+
+        nl = len(labels)  # number of labels
+        if nl:
+            labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1E-3)
+
+        if self.augment:
+            # Albumentations
+            img, labels = self.albumentations(img, labels)
+            nl = len(labels)  # update after albumentations
+
+            # HSV color-space
+            augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
+
+            # Flip up-down
+            if random.random() < hyp['flipud']:
+                img = np.flipud(img)
+                if nl:
+                    labels[:, 2] = 1 - labels[:, 2]
+
+            # Flip left-right
+            if random.random() < hyp['fliplr']:
+                img = np.fliplr(img)
+                if nl:
+                    labels[:, 1] = 1 - labels[:, 1]
+
+            # Cutouts
+            # labels = cutout(img, labels, p=0.5)
+            # nl = len(labels)  # update after cutout
+
+        labels_out = torch.zeros((nl, 6))
+        if nl:
+            labels_out[:, 1:] = torch.from_numpy(labels)
+
+        # Convert
+        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
+        img = np.ascontiguousarray(img)
+
+        return torch.from_numpy(img), labels_out, self.im_files[index], shapes
+
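For orientation, `__getitem__` returns a CHW image tensor plus an `(n, 6)` label tensor whose first column is left zero for the batch index that `collate_fn` fills in later. A minimal inspection sketch, assuming a coco128-style path and default arguments (the real constructor takes more options):

# Hypothetical single-item inspection of the dataset.
dataset = LoadImagesAndLabels('../datasets/coco128/images/train2017', img_size=640)  # placeholder path
img, labels, file, shapes = dataset[0]
print(img.shape)     # torch.Size([3, H, W]) - CHW, RGB
print(labels.shape)  # torch.Size([n, 6]) - (batch_index, class, x, y, w, h), normalized xywh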
+    def load_image(self, i):
+        # Loads 1 image from dataset index 'i', returns (im, original hw, resized hw)
+        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
+        if im is None:  # not cached in RAM
+            if fn.exists():  # load npy
+                im = np.load(fn)
+            else:  # read image
+                im = cv2.imread(f)  # BGR
+                assert im is not None, f'Image Not Found {f}'
+            h0, w0 = im.shape[:2]  # orig hw
+            r = self.img_size / max(h0, w0)  # ratio
+            if r != 1:  # if sizes are not equal
+                interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
+                im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
+            return im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
+        return self.ims[i], self.im_hw0[i], self.im_hw[i]  # im, hw_original, hw_resized
+
+    def cache_images_to_disk(self, i):
+        # Saves an image as an *.npy file for faster loading
+        f = self.npy_files[i]
+        if not f.exists():
+            np.save(f.as_posix(), cv2.imread(self.im_files[i]))
+
+    def load_mosaic(self, index):
+        # YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
+        labels4, segments4 = [], []
+        s = self.img_size
+        yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
+        indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
+        random.shuffle(indices)
+        for i, index in enumerate(indices):
+            # Load image
+            img, _, (h, w) = self.load_image(index)
+
+            # place img in img4
+            if i == 0:  # top left
+                img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
+                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
+                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
+            elif i == 1:  # top right
+                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
+                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
+            elif i == 2:  # bottom left
+                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
+                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
+            elif i == 3:  # bottom right
+                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
+                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
+
+            img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
+            padw = x1a - x1b
+            padh = y1a - y1b
+
+            # Labels
+            labels, segments = self.labels[index].copy(), self.segments[index].copy()
+            if labels.size:
+                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
+                segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
+            labels4.append(labels)
+            segments4.extend(segments)
+
+        # Concat/clip labels
+        labels4 = np.concatenate(labels4, 0)
+        for x in (labels4[:, 1:], *segments4):
+            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
+        # img4, labels4 = replicate(img4, labels4)  # replicate
+
+        # Augment
+        img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
+        img4, labels4 = random_perspective(img4,
+                                           labels4,
+                                           segments4,
+                                           degrees=self.hyp['degrees'],
+                                           translate=self.hyp['translate'],
+                                           scale=self.hyp['scale'],
+                                           shear=self.hyp['shear'],
+                                           perspective=self.hyp['perspective'],
+                                           border=self.mosaic_border)  # border to remove
+
+        return img4, labels4
+
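The label bookkeeping in the mosaic loader hinges on `xywhn2xyxy`: normalized center-format boxes are mapped to pixel corner format and shifted by the tile's padding offset. A self-contained sketch of that conversion (a standalone re-implementation for illustration, not the repo's utility):

import numpy as np

def xywhn2xyxy_demo(x, w, h, padw=0, padh=0):
    # Normalized (cx, cy, bw, bh) -> pixel (x1, y1, x2, y2), offset by tile padding
    y = np.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # x1
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # y1
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # x2
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # y2
    return y

# A centered box covering half a 640x640 image, pasted into a tile offset by (320, 0):
print(xywhn2xyxy_demo(np.array([[0.5, 0.5, 0.5, 0.5]]), w=640, h=640, padw=320))
# -> [[480. 160. 800. 480.]]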
+    def load_mosaic9(self, index):
+        # YOLOv5 9-mosaic loader. Loads 1 image + 8 random images into a 9-image mosaic
+        labels9, segments9 = [], []
+        s = self.img_size
+        indices = [index] + random.choices(self.indices, k=8)  # 8 additional image indices
+        random.shuffle(indices)
+        hp, wp = -1, -1  # height, width previous
+        for i, index in enumerate(indices):
+            # Load image
+            img, _, (h, w) = self.load_image(index)
+
+            # place img in img9
+            if i == 0:  # center
+                img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8)  # base image with 9 tiles
+                h0, w0 = h, w
+                c = s, s, s + w, s + h  # xmin, ymin, xmax, ymax (base) coordinates
+            elif i == 1:  # top
+                c = s, s - h, s + w, s
+            elif i == 2:  # top right
+                c = s + wp, s - h, s + wp + w, s
+            elif i == 3:  # right
+                c = s + w0, s, s + w0 + w, s + h
+            elif i == 4:  # bottom right
+                c = s + w0, s + hp, s + w0 + w, s + hp + h
+            elif i == 5:  # bottom
+                c = s + w0 - w, s + h0, s + w0, s + h0 + h
+            elif i == 6:  # bottom left
+                c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h
+            elif i == 7:  # left
+                c = s - w, s + h0 - h, s, s + h0
+            elif i == 8:  # top left
+                c = s - w, s + h0 - hp - h, s, s + h0 - hp
+
+            padx, pady = c[:2]
+            x1, y1, x2, y2 = (max(x, 0) for x in c)  # allocate coords
+
+            # Labels
+            labels, segments = self.labels[index].copy(), self.segments[index].copy()
+            if labels.size:
+                labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padx, pady)  # normalized xywh to pixel xyxy format
+                segments = [xyn2xy(x, w, h, padx, pady) for x in segments]
+            labels9.append(labels)
+            segments9.extend(segments)
+
+            # Image
+            img9[y1:y2, x1:x2] = img[y1 - pady:, x1 - padx:]  # img9[ymin:ymax, xmin:xmax]
+            hp, wp = h, w  # height, width previous
+
+        # Offset
+        yc, xc = (int(random.uniform(0, s)) for _ in self.mosaic_border)  # mosaic center x, y
+        img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s]
+
+        # Concat/clip labels
+        labels9 = np.concatenate(labels9, 0)
+        labels9[:, [1, 3]] -= xc
+        labels9[:, [2, 4]] -= yc
+        c = np.array([xc, yc])  # centers
+        segments9 = [x - c for x in segments9]
+
+        for x in (labels9[:, 1:], *segments9):
+            np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
+        # img9, labels9 = replicate(img9, labels9)  # replicate
+
+        # Augment
+        img9, labels9, segments9 = copy_paste(img9, labels9, segments9, p=self.hyp['copy_paste'])
+        img9, labels9 = random_perspective(img9,
+                                           labels9,
+                                           segments9,
+                                           degrees=self.hyp['degrees'],
+                                           translate=self.hyp['translate'],
+                                           scale=self.hyp['scale'],
+                                           shear=self.hyp['shear'],
+                                           perspective=self.hyp['perspective'],
+                                           border=self.mosaic_border)  # border to remove
+
+        return img9, labels9
+
+    @staticmethod
+    def collate_fn(batch):
+        im, label, path, shapes = zip(*batch)  # transposed
+        for i, lb in enumerate(label):
+            lb[:, 0] = i  # add target image index for build_targets()
+        return torch.stack(im, 0), torch.cat(label, 0), path, shapes
+
+    @staticmethod
+    def collate_fn4(batch):
+        im, label, path, shapes = zip(*batch)  # transposed
+        n = len(shapes) // 4
+        im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
+
+        ho = torch.tensor([[0.0, 0, 0, 1, 0, 0]])
+        wo = torch.tensor([[0.0, 0, 1, 0, 0, 0]])
+        s = torch.tensor([[1, 1, 0.5, 0.5, 0.5, 0.5]])  # scale
+        for i in range(n):  # zidane torch.zeros(16,3,720,1280)  # BCHW
+            i *= 4
+            if random.random() < 0.5:
+                im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
+                                    align_corners=False)[0].type(im[i].type())
+                lb = label[i]
+            else:
+                im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
+                lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
+            im4.append(im1)
+            label4.append(lb)
+
+        for i, lb in enumerate(label4):
+            lb[:, 0] = i  # add target image index for build_targets()
+
+        return torch.stack(im4, 0), torch.cat(label4, 0), path4, shapes4
+
+
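Because each item carries a variable number of labels, the dataset must be paired with its own `collate_fn`; default PyTorch collation would fail on ragged label tensors. A minimal sketch, assuming a `LoadImagesAndLabels` instance built elsewhere:

from torch.utils.data import DataLoader

# collate_fn stacks images and concatenates labels, writing the batch index
# into column 0 of every label row so targets can be routed back to images.
loader = DataLoader(dataset,  # assumed LoadImagesAndLabels instance
                    batch_size=16,
                    collate_fn=LoadImagesAndLabels.collate_fn)
for imgs, targets, paths, shapes in loader:
    print(imgs.shape)     # (16, 3, H, W)
    print(targets.shape)  # (total_labels_in_batch, 6)
    break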
+
+# Ancillary functions --------------------------------------------------------------------------------------------------
+def flatten_recursive(path=DATASETS_DIR / 'coco128'):
+    # Flatten a recursive directory by bringing all files to top level
+    new_path = Path(f'{str(path)}_flat')
+    if os.path.exists(new_path):
+        shutil.rmtree(new_path)  # delete output folder
+    os.makedirs(new_path)  # make new output folder
+    for file in tqdm(glob.glob(f'{str(Path(path))}/**/*.*', recursive=True)):
+        shutil.copyfile(file, new_path / Path(file).name)
+
+
+def extract_boxes(path=DATASETS_DIR / 'coco128'):  # from utils.dataloaders import *; extract_boxes()
+    # Convert detection dataset into classification dataset, with one directory per class
+    path = Path(path)  # images dir
+    shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None  # remove existing output folder
+    files = list(path.rglob('*.*'))
+    n = len(files)  # number of files
+    for im_file in tqdm(files, total=n):
+        if im_file.suffix[1:] in IMG_FORMATS:
+            # image
+            im = cv2.imread(str(im_file))[..., ::-1]  # BGR to RGB
+            h, w = im.shape[:2]
+
+            # labels
+            lb_file = Path(img2label_paths([str(im_file)])[0])
+            if Path(lb_file).exists():
+                with open(lb_file) as f:
+                    lb = np.array([x.split() for x in f.read().strip().splitlines()], dtype=np.float32)  # labels
+
+                for j, x in enumerate(lb):
+                    c = int(x[0])  # class
+                    f = (path / 'classifier') / f'{c}' / f'{path.stem}_{im_file.stem}_{j}.jpg'  # new filename
+                    if not f.parent.is_dir():
+                        f.parent.mkdir(parents=True)
+
+                    b = x[1:] * [w, h, w, h]  # box
+                    # b[2:] = b[2:].max()  # rectangle to square
+                    b[2:] = b[2:] * 1.2 + 3  # pad
+                    b = xywh2xyxy(b.reshape(-1, 4)).ravel().astype(int)
+
+                    b[[0, 2]] = np.clip(b[[0, 2]], 0, w)  # clip boxes outside of image
+                    b[[1, 3]] = np.clip(b[[1, 3]], 0, h)
+                    assert cv2.imwrite(str(f), im[b[1]:b[3], b[0]:b[2]]), f'box failure in {f}'
+
+
+def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), annotated_only=False):
+    """ Autosplit a dataset into train/val/test splits and save path/autosplit_*.txt files
+    Usage: from utils.dataloaders import *; autosplit()
+    Arguments
+        path:            Path to images directory
+        weights:         Train, val, test weights (list, tuple)
+        annotated_only:  Only use images with an annotated txt file
+    """
+    path = Path(path)  # images dir
+    files = sorted(x for x in path.rglob('*.*') if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
+    n = len(files)  # number of files
+    random.seed(0)  # for reproducibility
+    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split
+
+    txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt']  # 3 txt files
+    for x in txt:
+        if (path.parent / x).exists():
+            (path.parent / x).unlink()  # remove existing
+
+    print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
+    for i, img in tqdm(zip(indices, files), total=n):
+        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
+            with open(path.parent / txt[i], 'a') as f:
+                f.write(f'./{img.relative_to(path.parent).as_posix()}' + '\n')  # add image to txt file
+
+
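A quick sketch of `autosplit` with a custom 80/20 split (the dataset path is a placeholder):

# Writes autosplit_train.txt / autosplit_val.txt / autosplit_test.txt next to
# the images directory; a zero test weight assigns no images to the test split.
autosplit(path='../datasets/coco128/images',  # placeholder path
          weights=(0.8, 0.2, 0.0),
          annotated_only=True)  # skip images without a label .txt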
+def verify_image_label(args):
+    # Verify one image-label pair
+    im_file, lb_file, prefix = args
+    nm, nf, ne, nc, msg, segments = 0, 0, 0, 0, '', []  # number (missing, found, empty, corrupt), message, segments
+    try:
+        # verify images
+        im = Image.open(im_file)
+        im.verify()  # PIL verify
+        shape = exif_size(im)  # image size
+        assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
+        assert im.format.lower() in IMG_FORMATS, f'invalid image format {im.format}'
+        if im.format.lower() in ('jpg', 'jpeg'):
+            with open(im_file, 'rb') as f:
+                f.seek(-2, 2)
+                if f.read() != b'\xff\xd9':  # corrupt JPEG
+                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, 'JPEG', subsampling=0, quality=100)
+                    msg = f'{prefix}WARNING ⚠️ {im_file}: corrupt JPEG restored and saved'
+
+        # verify labels
+        if os.path.isfile(lb_file):
+            nf = 1  # label found
+            with open(lb_file) as f:
+                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
+                if any(len(x) > 6 for x in lb):  # is segment
+                    classes = np.array([x[0] for x in lb], dtype=np.float32)
+                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
+                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
+                lb = np.array(lb, dtype=np.float32)
+            nl = len(lb)
+            if nl:
+                assert lb.shape[1] == 5, f'labels require 5 columns, {lb.shape[1]} columns detected'
+                assert (lb >= 0).all(), f'negative label values {lb[lb < 0]}'
+                assert (lb[:, 1:] <= 1).all(), f'non-normalized or out of bounds coordinates {lb[:, 1:][lb[:, 1:] > 1]}'
+                _, i = np.unique(lb, axis=0, return_index=True)
+                if len(i) < nl:  # duplicate row check
+                    lb = lb[i]  # remove duplicates
+                    if segments:
+                        segments = [segments[x] for x in i]
+                    msg = f'{prefix}WARNING ⚠️ {im_file}: {nl - len(i)} duplicate labels removed'
+            else:
+                ne = 1  # label empty
+                lb = np.zeros((0, 5), dtype=np.float32)
+        else:
+            nm = 1  # label missing
+            lb = np.zeros((0, 5), dtype=np.float32)
+        return im_file, lb, shape, segments, nm, nf, ne, nc, msg
+    except Exception as e:
+        nc = 1
+        msg = f'{prefix}WARNING ⚠️ {im_file}: ignoring corrupt image/label: {e}'
+        return [None, None, None, None, nm, nf, ne, nc, msg]
+
+
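`verify_image_label` is normally driven by the multiprocessing pool in `cache_labels`, but it can also be called directly on one pair; a sketch with placeholder paths:

im_file, lb_file = 'images/train/0001.jpg', 'labels/train/0001.txt'  # placeholder paths
im_file, lb, shape, segments, nm, nf, ne, nc, msg = verify_image_label((im_file, lb_file, ''))
if nc:  # corrupt image/label pair
    print(msg)
else:
    print(f'{len(lb)} labels, image size {shape}')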
+class HUBDatasetStats():
+    """ Class for generating HUB dataset JSON and `-hub` dataset directory
+
+    Arguments
+        path:           Path to data.yaml or data.zip (with data.yaml inside data.zip)
+        autodownload:   Attempt to download dataset if not found locally
+
+    Usage
+        from utils.dataloaders import HUBDatasetStats
+        stats = HUBDatasetStats('coco128.yaml', autodownload=True)  # usage 1
+        stats = HUBDatasetStats('path/to/coco128.zip')  # usage 2
+        stats.get_json(save=False)
+        stats.process_images()
+    """
+
+    def __init__(self, path='coco128.yaml', autodownload=False):
+        # Initialize class
+        zipped, data_dir, yaml_path = self._unzip(Path(path))
+        try:
+            with open(check_yaml(yaml_path), errors='ignore') as f:
+                data = yaml.safe_load(f)  # data dict
+                if zipped:
+                    data['path'] = data_dir
+        except Exception as e:
+            raise Exception("error/HUB/dataset_stats/yaml_load") from e
+
+        check_dataset(data, autodownload)  # download dataset if missing
+        self.hub_dir = Path(data['path'] + '-hub')
+        self.im_dir = self.hub_dir / 'images'
+        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes /images
+        self.stats = {'nc': data['nc'], 'names': list(data['names'].values())}  # statistics dictionary
+        self.data = data
+
+    @staticmethod
+    def _find_yaml(dir):
+        # Return data.yaml file
+        files = list(dir.glob('*.yaml')) or list(dir.rglob('*.yaml'))  # try root level first and then recursive
+        assert files, f'No *.yaml file found in {dir}'
+        if len(files) > 1:
+            files = [f for f in files if f.stem == dir.stem]  # prefer *.yaml files that match dir name
+            assert files, f'Multiple *.yaml files found in {dir}, only 1 *.yaml file allowed'
+        assert len(files) == 1, f'Multiple *.yaml files found: {files}, only 1 *.yaml file allowed in {dir}'
+        return files[0]
+
+    def _unzip(self, path):
+        # Unzip data.zip
+        if not str(path).endswith('.zip'):  # path is data.yaml
+            return False, None, path
+        assert Path(path).is_file(), f'Error unzipping {path}, file not found'
+        unzip_file(path, path=path.parent)
+        dir = path.with_suffix('')  # dataset directory == zip name
+        assert dir.is_dir(), f'Error unzipping {path}, {dir} not found. path/to/abc.zip MUST unzip to path/to/abc/'
+        return True, str(dir), self._find_yaml(dir)  # zipped, data_dir, yaml_path
+
+    def _hub_ops(self, f, max_dim=1920):
+        # HUB ops for 1 image 'f': resize and save at reduced quality in /dataset-hub for web/app viewing
+        f_new = self.im_dir / Path(f).name  # dataset-hub image filename
+        try:  # use PIL
+            im = Image.open(f)
+            r = max_dim / max(im.height, im.width)  # ratio
+            if r < 1.0:  # image too large
+                im = im.resize((int(im.width * r), int(im.height * r)))
+            im.save(f_new, 'JPEG', quality=50, optimize=True)  # save
+        except Exception as e:  # use OpenCV
+            LOGGER.info(f'WARNING ⚠️ HUB ops PIL failure {f}: {e}')
+            im = cv2.imread(f)
+            im_height, im_width = im.shape[:2]
+            r = max_dim / max(im_height, im_width)  # ratio
+            if r < 1.0:  # image too large
+                im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
+            cv2.imwrite(str(f_new), im)
+
+    def get_json(self, save=False, verbose=False):
+        # Return dataset JSON for Ultralytics HUB
+        def _round(labels):
+            # Update labels to integer class and 4 decimal place floats
+            return [[int(c), *(round(x, 4) for x in points)] for c, *points in labels]
+
+        for split in 'train', 'val', 'test':
+            if self.data.get(split) is None:
+                self.stats[split] = None  # i.e. no test set
+                continue
+            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
+            x = np.array([
+                np.bincount(label[:, 0].astype(int), minlength=self.data['nc'])
+                for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics')])  # shape(128x80)
+            self.stats[split] = {
+                'instance_stats': {
+                    'total': int(x.sum()),
+                    'per_class': x.sum(0).tolist()},
+                'image_stats': {
+                    'total': dataset.n,
+                    'unlabelled': int(np.all(x == 0, 1).sum()),
+                    'per_class': (x > 0).sum(0).tolist()},
+                'labels': [{
+                    str(Path(k).name): _round(v.tolist())} for k, v in zip(dataset.im_files, dataset.labels)]}
+
+        # Save, print and return
+        if save:
+            stats_path = self.hub_dir / 'stats.json'
+            print(f'Saving {stats_path.resolve()}...')
+            with open(stats_path, 'w') as f:
+                json.dump(self.stats, f)  # save stats.json
+        if verbose:
+            print(json.dumps(self.stats, indent=2, sort_keys=False))
+        return self.stats
+
+    def process_images(self):
+        # Compress images for Ultralytics HUB
+        for split in 'train', 'val', 'test':
+            if self.data.get(split) is None:
+                continue
+            dataset = LoadImagesAndLabels(self.data[split])  # load dataset
+            desc = f'{split} images'
+            for _ in tqdm(ThreadPool(NUM_THREADS).imap(self._hub_ops, dataset.im_files), total=dataset.n, desc=desc):
+                pass
+        print(f'Done. All images saved to {self.im_dir}')
+        return self.im_dir
+
+
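Following the class docstring, a short end-to-end sketch of the stats workflow (the yaml name and printed value are illustrative):

stats = HUBDatasetStats('coco128.yaml', autodownload=True)
split_stats = stats.get_json(save=True)  # writes coco128-hub/stats.json
print(split_stats['train']['image_stats']['total'])  # e.g. 128 for coco128
stats.process_images()  # writes compressed copies under coco128-hub/images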
+# Classification dataloaders -------------------------------------------------------------------------------------------
+class ClassificationDataset(torchvision.datasets.ImageFolder):
+    """
+    YOLOv5 Classification Dataset.
+    Arguments
+        root:  Dataset path
+        transform:  torchvision transforms, used by default
+        album_transform: Albumentations transforms, used if installed
+    """
+
+    def __init__(self, root, augment, imgsz, cache=False):
+        super().__init__(root=root)
+        self.torch_transforms = classify_transforms(imgsz)
+        self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
+        self.cache_ram = cache is True or cache == 'ram'
+        self.cache_disk = cache == 'disk'
+        self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
+
+    def __getitem__(self, i):
+        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
+        if self.cache_ram and im is None:
+            im = self.samples[i][3] = cv2.imread(f)
+        elif self.cache_disk:
+            if not fn.exists():  # load npy
+                np.save(fn.as_posix(), cv2.imread(f))
+            im = np.load(fn)
+        else:  # read image
+            im = cv2.imread(f)  # BGR
+        if self.album_transforms:
+            sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
+        else:
+            sample = self.torch_transforms(im)
+        return sample, j
+
+
+def create_classification_dataloader(path,
+                                     imgsz=224,
+                                     batch_size=16,
+                                     augment=True,
+                                     cache=False,
+                                     rank=-1,
+                                     workers=8,
+                                     shuffle=True):
+    # Returns Dataloader object to be used with YOLOv5 Classifier
+    with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
+        dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
+    batch_size = min(batch_size, len(dataset))
+    nd = torch.cuda.device_count()
+    nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
+    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
+    generator = torch.Generator()
+    generator.manual_seed(6148914691236517205 + RANK)
+    return InfiniteDataLoader(dataset,
+                              batch_size=batch_size,
+                              shuffle=shuffle and sampler is None,
+                              num_workers=nw,
+                              sampler=sampler,
+                              pin_memory=PIN_MEMORY,
+                              worker_init_fn=seed_worker,
+                              generator=generator)  # or DataLoader(persistent_workers=True)
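
Finally, a sketch of the classification dataloader in use (the dataset path is a placeholder; an `ImageFolder` layout with one subdirectory per class is expected):

loader = create_classification_dataloader('../datasets/imagenette160/train',  # placeholder path
                                          imgsz=224,
                                          batch_size=64,
                                          augment=True,
                                          workers=4)
for ims, labels in loader:
    print(ims.shape, labels.shape)  # (64, 3, 224, 224), (64,)
    break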