aakashv100 commited on
Commit
9201e90
·
1 Parent(s): 666297a

added lfs

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +33 -0
  2. yolov9/LICENSE.md +674 -0
  3. yolov9/README.md +335 -0
  4. yolov9/__pycache__/export.cpython-311.pyc +0 -0
  5. yolov9/__pycache__/val_dual.cpython-311.pyc +0 -0
  6. yolov9/benchmarks.py +142 -0
  7. yolov9/classify/predict.py +224 -0
  8. yolov9/classify/train.py +333 -0
  9. yolov9/classify/val.py +170 -0
  10. yolov9/detect.py +231 -0
  11. yolov9/detect_dual.py +232 -0
  12. yolov9/export.py +686 -0
  13. yolov9/figure/horses_prediction.jpg +0 -0
  14. yolov9/figure/multitask.png +3 -0
  15. yolov9/figure/performance.png +3 -0
  16. yolov9/hubconf.py +107 -0
  17. yolov9/models/__init__.py +1 -0
  18. yolov9/models/__pycache__/__init__.cpython-311.pyc +0 -0
  19. yolov9/models/__pycache__/common.cpython-311.pyc +0 -0
  20. yolov9/models/__pycache__/experimental.cpython-311.pyc +0 -0
  21. yolov9/models/__pycache__/yolo.cpython-311.pyc +0 -0
  22. yolov9/models/common.py +1233 -0
  23. yolov9/models/detect/gelan-c.yaml +80 -0
  24. yolov9/models/detect/gelan-e.yaml +121 -0
  25. yolov9/models/detect/gelan-m.yaml +80 -0
  26. yolov9/models/detect/gelan-s.yaml +80 -0
  27. yolov9/models/detect/gelan-t.yaml +80 -0
  28. yolov9/models/detect/gelan.yaml +80 -0
  29. yolov9/models/detect/yolov7-af.yaml +137 -0
  30. yolov9/models/detect/yolov9-cf.yaml +124 -0
  31. yolov9/models/detect/yolov9-m.yaml +117 -0
  32. yolov9/models/detect/yolov9-s.yaml +97 -0
  33. yolov9/models/detect/yolov9-t.yaml +97 -0
  34. yolov9/models/detect/yolov9.yaml +117 -0
  35. yolov9/models/experimental.py +275 -0
  36. yolov9/models/hub/anchors.yaml +59 -0
  37. yolov9/models/hub/yolov3-spp.yaml +51 -0
  38. yolov9/models/hub/yolov3-tiny.yaml +41 -0
  39. yolov9/models/hub/yolov3.yaml +51 -0
  40. yolov9/models/panoptic/gelan-c-pan.yaml +80 -0
  41. yolov9/models/panoptic/yolov7-af-pan.yaml +137 -0
  42. yolov9/models/segment/gelan-c-dseg.yaml +84 -0
  43. yolov9/models/segment/gelan-c-seg.yaml +80 -0
  44. yolov9/models/segment/yolov7-af-seg.yaml +136 -0
  45. yolov9/models/segment/yolov9-c-dseg.yaml +130 -0
  46. yolov9/models/tf.py +596 -0
  47. yolov9/models/yolo.py +818 -0
  48. yolov9/panoptic/predict.py +246 -0
  49. yolov9/panoptic/train.py +662 -0
  50. yolov9/panoptic/val.py +597 -0
app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import uuid
4
+ from PIL import Image
5
+ import os
6
+
7
+
8
+ def inference(img):
9
+ name_tag = str(uuid.uuid4())
10
+ script_command = f"python yolov9/detect_dual.py --source {img} --img 640 --device cpu --weights yolov9/runs/train/exp2/weights/best.pt --name {name_tag}"
11
+ os.system(script_command)
12
+ output_file = f"yolov9/runs/detect/{name_tag}/{img.split('/')[-1]}"
13
+ return output_file
14
+
15
+ title = "Sketch GUI Element Detection"
16
+
17
+ description = "This is a demo for detecting GUI elements in a sketch image. Upload a sketch image and the model will detect the GUI elements in the image."
18
+
19
+ img_input = gr.Image(type="filepath", label="Upload a sketch image", width=300, height=300)
20
+ prediction_output = gr.Image(label="Output Image", width=640, height=640)
21
+
22
+ example_lst = [
23
+ ["test_images/Shipping-1.png"],
24
+ ]
25
+
26
+ demo = gr.Interface(fn=inference,
27
+ inputs=img_input,
28
+ outputs=prediction_output,
29
+ title=title,
30
+ description=description,
31
+ examples=example_lst)
32
+
33
+ demo.launch(debug=True)
yolov9/LICENSE.md ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
yolov9/README.md ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ Implementation of paper - [YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information](https://arxiv.org/abs/2402.13616)
4
+
5
+ [![arxiv.org](http://img.shields.io/badge/cs.CV-arXiv%3A2402.13616-B31B1B.svg)](https://arxiv.org/abs/2402.13616)
6
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/kadirnar/Yolov9)
7
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/merve/yolov9)
8
+ [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-yolov9-object-detection-on-custom-dataset.ipynb)
9
+ [![OpenCV](https://img.shields.io/badge/OpenCV-BlogPost-black?logo=opencv&labelColor=blue&color=black)](https://learnopencv.com/yolov9-advancing-the-yolo-legacy/)
10
+
11
+ <div align="center">
12
+ <a href="./">
13
+ <img src="./figure/performance.png" width="79%"/>
14
+ </a>
15
+ </div>
16
+
17
+
18
+ ## Performance
19
+
20
+ MS COCO
21
+
22
+ | Model | Test Size | AP<sup>val</sup> | AP<sub>50</sub><sup>val</sup> | AP<sub>75</sub><sup>val</sup> | Param. | FLOPs |
23
+ | :-- | :-: | :-: | :-: | :-: | :-: | :-: |
24
+ | [**YOLOv9-T**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-t-converted.pt) | 640 | **38.3%** | **53.1%** | **41.3%** | **2.0M** | **7.7G** |
25
+ | [**YOLOv9-S**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-s-converted.pt) | 640 | **46.8%** | **63.4%** | **50.7%** | **7.1M** | **26.4G** |
26
+ | [**YOLOv9-M**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-m-converted.pt) | 640 | **51.4%** | **68.1%** | **56.1%** | **20.0M** | **76.3G** |
27
+ | [**YOLOv9-C**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c-converted.pt) | 640 | **53.0%** | **70.2%** | **57.8%** | **25.3M** | **102.1G** |
28
+ | [**YOLOv9-E**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-e-converted.pt) | 640 | **55.6%** | **72.8%** | **60.6%** | **57.3M** | **189.0G** |
29
+ <!-- | [**YOLOv9 (ReLU)**]() | 640 | **51.9%** | **69.1%** | **56.5%** | **25.3M** | **102.1G** | -->
30
+
31
+ <!-- tiny, small, and medium models will be released after the paper be accepted and published. -->
32
+
33
+ ## Useful Links
34
+
35
+ <details><summary> <b>Expand</b> </summary>
36
+
37
+ Custom training: https://github.com/WongKinYiu/yolov9/issues/30#issuecomment-1960955297
38
+
39
+ ONNX export: https://github.com/WongKinYiu/yolov9/issues/2#issuecomment-1960519506 https://github.com/WongKinYiu/yolov9/issues/40#issue-2150697688 https://github.com/WongKinYiu/yolov9/issues/130#issue-2162045461
40
+
41
+ ONNX export for segmentation: https://github.com/WongKinYiu/yolov9/issues/260#issue-2191162150
42
+
43
+ TensorRT inference: https://github.com/WongKinYiu/yolov9/issues/143#issuecomment-1975049660 https://github.com/WongKinYiu/yolov9/issues/34#issue-2150393690 https://github.com/WongKinYiu/yolov9/issues/79#issue-2153547004 https://github.com/WongKinYiu/yolov9/issues/143#issue-2164002309
44
+
45
+ QAT TensorRT: https://github.com/WongKinYiu/yolov9/issues/327#issue-2229284136 https://github.com/WongKinYiu/yolov9/issues/253#issue-2189520073
46
+
47
+ TensorRT inference for segmentation: https://github.com/WongKinYiu/yolov9/issues/446
48
+
49
+ TFLite: https://github.com/WongKinYiu/yolov9/issues/374#issuecomment-2065751706
50
+
51
+ OpenVINO: https://github.com/WongKinYiu/yolov9/issues/164#issue-2168540003
52
+
53
+ C# ONNX inference: https://github.com/WongKinYiu/yolov9/issues/95#issue-2155974619
54
+
55
+ C# OpenVINO inference: https://github.com/WongKinYiu/yolov9/issues/95#issuecomment-1968131244
56
+
57
+ OpenCV: https://github.com/WongKinYiu/yolov9/issues/113#issuecomment-1971327672
58
+
59
+ Hugging Face demo: https://github.com/WongKinYiu/yolov9/issues/45#issuecomment-1961496943
60
+
61
+ CoLab demo: https://github.com/WongKinYiu/yolov9/pull/18
62
+
63
+ ONNXSlim export: https://github.com/WongKinYiu/yolov9/pull/37
64
+
65
+ YOLOv9 ROS: https://github.com/WongKinYiu/yolov9/issues/144#issue-2164210644
66
+
67
+ YOLOv9 ROS TensorRT: https://github.com/WongKinYiu/yolov9/issues/145#issue-2164218595
68
+
69
+ YOLOv9 Julia: https://github.com/WongKinYiu/yolov9/issues/141#issuecomment-1973710107
70
+
71
+ YOLOv9 MLX: https://github.com/WongKinYiu/yolov9/issues/258#issue-2190586540
72
+
73
+ YOLOv9 StrongSORT with OSNet: https://github.com/WongKinYiu/yolov9/issues/299#issue-2212093340
74
+
75
+ YOLOv9 ByteTrack: https://github.com/WongKinYiu/yolov9/issues/78#issue-2153512879
76
+
77
+ YOLOv9 DeepSORT: https://github.com/WongKinYiu/yolov9/issues/98#issue-2156172319
78
+
79
+ YOLOv9 counting: https://github.com/WongKinYiu/yolov9/issues/84#issue-2153904804
80
+
81
+ YOLOv9 speed estimation: https://github.com/WongKinYiu/yolov9/issues/456
82
+
83
+ YOLOv9 face detection: https://github.com/WongKinYiu/yolov9/issues/121#issue-2160218766
84
+
85
+ YOLOv9 segmentation onnxruntime: https://github.com/WongKinYiu/yolov9/issues/151#issue-2165667350
86
+
87
+ Comet logging: https://github.com/WongKinYiu/yolov9/pull/110
88
+
89
+ MLflow logging: https://github.com/WongKinYiu/yolov9/pull/87
90
+
91
+ AnyLabeling tool: https://github.com/WongKinYiu/yolov9/issues/48#issue-2152139662
92
+
93
+ AX650N deploy: https://github.com/WongKinYiu/yolov9/issues/96#issue-2156115760
94
+
95
+ Conda environment: https://github.com/WongKinYiu/yolov9/pull/93
96
+
97
+ AutoDL docker environment: https://github.com/WongKinYiu/yolov9/issues/112#issue-2158203480
98
+
99
+ </details>
100
+
101
+
102
+ ## Installation
103
+
104
+ Docker environment (recommended)
105
+ <details><summary> <b>Expand</b> </summary>
106
+
107
+ ``` shell
108
+ # create the docker container, you can change the share memory size if you have more.
109
+ nvidia-docker run --name yolov9 -it -v your_coco_path/:/coco/ -v your_code_path/:/yolov9 --shm-size=64g nvcr.io/nvidia/pytorch:21.11-py3
110
+
111
+ # apt install required packages
112
+ apt update
113
+ apt install -y zip htop screen libgl1-mesa-glx
114
+
115
+ # pip install required packages
116
+ pip install seaborn thop
117
+
118
+ # go to code folder
119
+ cd /yolov9
120
+ ```
121
+
122
+ </details>
123
+
124
+
125
+ ## Evaluation
126
+
127
+ [`yolov9-s-converted.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-s-converted.pt) [`yolov9-m-converted.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-m-converted.pt) [`yolov9-c-converted.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c-converted.pt) [`yolov9-e-converted.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-e-converted.pt)
128
+ [`yolov9-s.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-s.pt) [`yolov9-m.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-m.pt) [`yolov9-c.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c.pt) [`yolov9-e.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-e.pt)
129
+ [`gelan-s.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-s.pt) [`gelan-m.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-m.pt) [`gelan-c.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c.pt) [`gelan-e.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-e.pt)
130
+
131
+ ``` shell
132
+ # evaluate converted yolov9 models
133
+ python val.py --data data/coco.yaml --img 640 --batch 32 --conf 0.001 --iou 0.7 --device 0 --weights './yolov9-c-converted.pt' --save-json --name yolov9_c_c_640_val
134
+
135
+ # evaluate yolov9 models
136
+ # python val_dual.py --data data/coco.yaml --img 640 --batch 32 --conf 0.001 --iou 0.7 --device 0 --weights './yolov9-c.pt' --save-json --name yolov9_c_640_val
137
+
138
+ # evaluate gelan models
139
+ # python val.py --data data/coco.yaml --img 640 --batch 32 --conf 0.001 --iou 0.7 --device 0 --weights './gelan-c.pt' --save-json --name gelan_c_640_val
140
+ ```
141
+
142
+ You will get the results:
143
+
144
+ ```
145
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.530
146
+ Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.702
147
+ Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.578
148
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.362
149
+ Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.585
150
+ Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.693
151
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.392
152
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.652
153
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.702
154
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.541
155
+ Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.760
156
+ Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.844
157
+ ```
158
+
159
+
160
+ ## Training
161
+
162
+ Data preparation
163
+
164
+ ``` shell
165
+ bash scripts/get_coco.sh
166
+ ```
167
+
168
+ * Download MS COCO dataset images ([train](http://images.cocodataset.org/zips/train2017.zip), [val](http://images.cocodataset.org/zips/val2017.zip), [test](http://images.cocodataset.org/zips/test2017.zip)) and [labels](https://github.com/WongKinYiu/yolov7/releases/download/v0.1/coco2017labels-segments.zip). If you have previously used a different version of YOLO, we strongly recommend that you delete `train2017.cache` and `val2017.cache` files, and redownload [labels](https://github.com/WongKinYiu/yolov7/releases/download/v0.1/coco2017labels-segments.zip)
169
+
170
+ Single GPU training
171
+
172
+ ``` shell
173
+ # train yolov9 models
174
+ python train_dual.py --workers 8 --device 0 --batch 16 --data data/coco.yaml --img 640 --cfg models/detect/yolov9-c.yaml --weights '' --name yolov9-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15
175
+
176
+ # train gelan models
177
+ # python train.py --workers 8 --device 0 --batch 32 --data data/coco.yaml --img 640 --cfg models/detect/gelan-c.yaml --weights '' --name gelan-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15
178
+ ```
179
+
180
+ Multiple GPU training
181
+
182
+ ``` shell
183
+ # train yolov9 models
184
+ python -m torch.distributed.launch --nproc_per_node 8 --master_port 9527 train_dual.py --workers 8 --device 0,1,2,3,4,5,6,7 --sync-bn --batch 128 --data data/coco.yaml --img 640 --cfg models/detect/yolov9-c.yaml --weights '' --name yolov9-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15
185
+
186
+ # train gelan models
187
+ # python -m torch.distributed.launch --nproc_per_node 4 --master_port 9527 train.py --workers 8 --device 0,1,2,3 --sync-bn --batch 128 --data data/coco.yaml --img 640 --cfg models/detect/gelan-c.yaml --weights '' --name gelan-c --hyp hyp.scratch-high.yaml --min-items 0 --epochs 500 --close-mosaic 15
188
+ ```
189
+
190
+
191
+ ## Re-parameterization
192
+
193
+ See [reparameterization.ipynb](https://github.com/WongKinYiu/yolov9/blob/main/tools/reparameterization.ipynb).
194
+
195
+
196
+ ## Inference
197
+
198
+ <div align="center">
199
+ <a href="./">
200
+ <img src="./figure/horses_prediction.jpg" width="49%"/>
201
+ </a>
202
+ </div>
203
+
204
+ ``` shell
205
+ # inference converted yolov9 models
206
+ python detect.py --source './data/images/horses.jpg' --img 640 --device 0 --weights './yolov9-c-converted.pt' --name yolov9_c_c_640_detect
207
+
208
+ # inference yolov9 models
209
+ # python detect_dual.py --source './data/images/horses.jpg' --img 640 --device 0 --weights './yolov9-c.pt' --name yolov9_c_640_detect
210
+
211
+ # inference gelan models
212
+ # python detect.py --source './data/images/horses.jpg' --img 640 --device 0 --weights './gelan-c.pt' --name gelan_c_c_640_detect
213
+ ```
214
+
215
+
216
+ ## Citation
217
+
218
+ ```
219
+ @article{wang2024yolov9,
220
+ title={{YOLOv9}: Learning What You Want to Learn Using Programmable Gradient Information},
221
+ author={Wang, Chien-Yao and Liao, Hong-Yuan Mark},
222
+ booktitle={arXiv preprint arXiv:2402.13616},
223
+ year={2024}
224
+ }
225
+ ```
226
+
227
+ ```
228
+ @article{chang2023yolor,
229
+ title={{YOLOR}-Based Multi-Task Learning},
230
+ author={Chang, Hung-Shuo and Wang, Chien-Yao and Wang, Richard Robert and Chou, Gene and Liao, Hong-Yuan Mark},
231
+ journal={arXiv preprint arXiv:2309.16921},
232
+ year={2023}
233
+ }
234
+ ```
235
+
236
+
237
+ ## Teaser
238
+
239
+ Parts of code of [YOLOR-Based Multi-Task Learning](https://arxiv.org/abs/2309.16921) are released in the repository.
240
+
241
+ <div align="center">
242
+ <a href="./">
243
+ <img src="./figure/multitask.png" width="99%"/>
244
+ </a>
245
+ </div>
246
+
247
+ #### Object Detection
248
+
249
+ [`gelan-c-det.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-det.pt)
250
+
251
+ `object detection`
252
+
253
+ ``` shell
254
+ # coco/labels/{split}/*.txt
255
+ # bbox or polygon (1 instance 1 line)
256
+ python train.py --workers 8 --device 0 --batch 32 --data data/coco.yaml --img 640 --cfg models/detect/gelan-c.yaml --weights '' --name gelan-c-det --hyp hyp.scratch-high.yaml --min-items 0 --epochs 300 --close-mosaic 10
257
+ ```
258
+
259
+ | Model | Test Size | Param. | FLOPs | AP<sup>box</sup> |
260
+ | :-- | :-: | :-: | :-: | :-: |
261
+ | [**GELAN-C-DET**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-det.pt) | 640 | 25.3M | 102.1G |**52.3%** |
262
+ | [**YOLOv9-C-DET**]() | 640 | 25.3M | 102.1G | **53.0%** |
263
+
264
+ #### Instance Segmentation
265
+
266
+ [`gelan-c-seg.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-seg.pt)
267
+
268
+ `object detection` `instance segmentation`
269
+
270
+ ``` shell
271
+ # coco/labels/{split}/*.txt
272
+ # polygon (1 instance 1 line)
273
+ python segment/train.py --workers 8 --device 0 --batch 32 --data coco.yaml --img 640 --cfg models/segment/gelan-c-seg.yaml --weights '' --name gelan-c-seg --hyp hyp.scratch-high.yaml --no-overlap --epochs 300 --close-mosaic 10
274
+ ```
275
+
276
+ | Model | Test Size | Param. | FLOPs | AP<sup>box</sup> | AP<sup>mask</sup> |
277
+ | :-- | :-: | :-: | :-: | :-: | :-: |
278
+ | [**GELAN-C-SEG**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-seg.pt) | 640 | 27.4M | 144.6G | **52.3%** | **42.4%** |
279
+ | [**YOLOv9-C-SEG**]() | 640 | 27.4M | 145.5G | **53.3%** | **43.5%** |
280
+
281
+ #### Panoptic Segmentation
282
+
283
+ [`gelan-c-pan.pt`](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-pan.pt)
284
+
285
+ `object detection` `instance segmentation` `semantic segmentation` `stuff segmentation` `panoptic segmentation`
286
+
287
+ ``` shell
288
+ # coco/labels/{split}/*.txt
289
+ # polygon (1 instance 1 line)
290
+ # coco/stuff/{split}/*.txt
291
+ # polygon (1 semantic 1 line)
292
+ python panoptic/train.py --workers 8 --device 0 --batch 32 --data coco.yaml --img 640 --cfg models/panoptic/gelan-c-pan.yaml --weights '' --name gelan-c-pan --hyp hyp.scratch-high.yaml --no-overlap --epochs 300 --close-mosaic 10
293
+ ```
294
+
295
+ | Model | Test Size | Param. | FLOPs | AP<sup>box</sup> | AP<sup>mask</sup> | mIoU<sub>164k/10k</sub><sup>semantic</sup> | mIoU<sup>stuff</sup> | PQ<sup>panoptic</sup> |
296
+ | :-- | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
297
+ | [**GELAN-C-PAN**](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-pan.pt) | 640 | 27.6M | 146.7G | **52.6%** | **42.5%** | **39.0%/48.3%** | **52.7%** | **39.4%** |
298
+ | [**YOLOv9-C-PAN**]() | 640 | 28.8M | 187.0G | **52.7%** | **43.0%** | **39.8%/-** | **52.2%** | **40.5%** |
299
+
300
+ #### Image Captioning (not yet released)
301
+
302
+ <!--[`gelan-c-cap.pt`]()-->
303
+
304
+ `object detection` `instance segmentation` `semantic segmentation` `stuff segmentation` `panoptic segmentation` `image captioning`
305
+
306
+ ``` shell
307
+ # coco/labels/{split}/*.txt
308
+ # polygon (1 instance 1 line)
309
+ # coco/stuff/{split}/*.txt
310
+ # polygon (1 semantic 1 line)
311
+ # coco/annotations/*.json
312
+ # json (1 split 1 file)
313
+ python caption/train.py --workers 8 --device 0 --batch 32 --data coco.yaml --img 640 --cfg models/caption/gelan-c-cap.yaml --weights '' --name gelan-c-cap --hyp hyp.scratch-high.yaml --no-overlap --epochs 300 --close-mosaic 10
314
+ ```
315
+
316
+ | Model | Test Size | Param. | FLOPs | AP<sup>box</sup> | AP<sup>mask</sup> | mIoU<sub>164k/10k</sub><sup>semantic</sup> | mIoU<sup>stuff</sup> | PQ<sup>panoptic</sup> | BLEU@4<sup>caption</sup> | CIDEr<sup>caption</sup> |
317
+ | :-- | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
318
+ | [**GELAN-C-CAP**]() | 640 | 47.5M | - | **51.9%** | **42.6%** | **42.5%/-** | **56.5%** | **41.7%** | **38.8** | **122.3** |
319
+ | [**YOLOv9-C-CAP**]() | 640 | 47.5M | - | **52.1%** | **42.6%** | **43.0%/-** | **56.4%** | **42.1%** | **39.1** | **122.0** |
320
+ <!--| [**YOLOR-MT**]() | 640 | 79.3M | - | **51.0%** | **41.7%** | **-/49.6%** | **55.9%** | **40.5%** | **35.7** | **112.7** |-->
321
+
322
+
323
+ ## Acknowledgements
324
+
325
+ <details><summary> <b>Expand</b> </summary>
326
+
327
+ * [https://github.com/AlexeyAB/darknet](https://github.com/AlexeyAB/darknet)
328
+ * [https://github.com/WongKinYiu/yolor](https://github.com/WongKinYiu/yolor)
329
+ * [https://github.com/WongKinYiu/yolov7](https://github.com/WongKinYiu/yolov7)
330
+ * [https://github.com/VDIGPKU/DynamicDet](https://github.com/VDIGPKU/DynamicDet)
331
+ * [https://github.com/DingXiaoH/RepVGG](https://github.com/DingXiaoH/RepVGG)
332
+ * [https://github.com/ultralytics/yolov5](https://github.com/ultralytics/yolov5)
333
+ * [https://github.com/meituan/YOLOv6](https://github.com/meituan/YOLOv6)
334
+
335
+ </details>
yolov9/__pycache__/export.cpython-311.pyc ADDED
Binary file (47.9 kB). View file
 
yolov9/__pycache__/val_dual.cpython-311.pyc ADDED
Binary file (28.1 kB). View file
 
yolov9/benchmarks.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import platform
3
+ import sys
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import pandas as pd
8
+
9
+ FILE = Path(__file__).resolve()
10
+ ROOT = FILE.parents[0] # YOLO root directory
11
+ if str(ROOT) not in sys.path:
12
+ sys.path.append(str(ROOT)) # add ROOT to PATH
13
+ # ROOT = ROOT.relative_to(Path.cwd()) # relative
14
+
15
+ import export
16
+ from models.experimental import attempt_load
17
+ from models.yolo import SegmentationModel
18
+ from segment.val import run as val_seg
19
+ from utils import notebook_init
20
+ from utils.general import LOGGER, check_yaml, file_size, print_args
21
+ from utils.torch_utils import select_device
22
+ from val import run as val_det
23
+
24
+
25
+ def run(
26
+ weights=ROOT / 'yolo.pt', # weights path
27
+ imgsz=640, # inference size (pixels)
28
+ batch_size=1, # batch size
29
+ data=ROOT / 'data/coco.yaml', # dataset.yaml path
30
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
31
+ half=False, # use FP16 half-precision inference
32
+ test=False, # test exports only
33
+ pt_only=False, # test PyTorch only
34
+ hard_fail=False, # throw error on benchmark failure
35
+ ):
36
+ y, t = [], time.time()
37
+ device = select_device(device)
38
+ model_type = type(attempt_load(weights, fuse=False)) # DetectionModel, SegmentationModel, etc.
39
+ for i, (name, f, suffix, cpu, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, CPU, GPU)
40
+ try:
41
+ assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
42
+ assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
43
+ if 'cpu' in device.type:
44
+ assert cpu, 'inference not supported on CPU'
45
+ if 'cuda' in device.type:
46
+ assert gpu, 'inference not supported on GPU'
47
+
48
+ # Export
49
+ if f == '-':
50
+ w = weights # PyTorch format
51
+ else:
52
+ w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others
53
+ assert suffix in str(w), 'export failed'
54
+
55
+ # Validate
56
+ if model_type == SegmentationModel:
57
+ result = val_seg(data, w, batch_size, imgsz, plots=False, device=device, task='speed', half=half)
58
+ metric = result[0][7] # (box(p, r, map50, map), mask(p, r, map50, map), *loss(box, obj, cls))
59
+ else: # DetectionModel:
60
+ result = val_det(data, w, batch_size, imgsz, plots=False, device=device, task='speed', half=half)
61
+ metric = result[0][3] # (p, r, map50, map, *loss(box, obj, cls))
62
+ speed = result[2][1] # times (preprocess, inference, postprocess)
63
+ y.append([name, round(file_size(w), 1), round(metric, 4), round(speed, 2)]) # MB, mAP, t_inference
64
+ except Exception as e:
65
+ if hard_fail:
66
+ assert type(e) is AssertionError, f'Benchmark --hard-fail for {name}: {e}'
67
+ LOGGER.warning(f'WARNING ⚠️ Benchmark failure for {name}: {e}')
68
+ y.append([name, None, None, None]) # mAP, t_inference
69
+ if pt_only and i == 0:
70
+ break # break after PyTorch
71
+
72
+ # Print results
73
+ LOGGER.info('\n')
74
+ parse_opt()
75
+ notebook_init() # print system info
76
+ c = ['Format', 'Size (MB)', 'mAP50-95', 'Inference time (ms)'] if map else ['Format', 'Export', '', '']
77
+ py = pd.DataFrame(y, columns=c)
78
+ LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
79
+ LOGGER.info(str(py if map else py.iloc[:, :2]))
80
+ if hard_fail and isinstance(hard_fail, str):
81
+ metrics = py['mAP50-95'].array # values to compare to floor
82
+ floor = eval(hard_fail) # minimum metric floor to pass
83
+ assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: mAP50-95 < floor {floor}'
84
+ return py
85
+
86
+
87
+ def test(
88
+ weights=ROOT / 'yolo.pt', # weights path
89
+ imgsz=640, # inference size (pixels)
90
+ batch_size=1, # batch size
91
+ data=ROOT / 'data/coco128.yaml', # dataset.yaml path
92
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
93
+ half=False, # use FP16 half-precision inference
94
+ test=False, # test exports only
95
+ pt_only=False, # test PyTorch only
96
+ hard_fail=False, # throw error on benchmark failure
97
+ ):
98
+ y, t = [], time.time()
99
+ device = select_device(device)
100
+ for i, (name, f, suffix, gpu) in export.export_formats().iterrows(): # index, (name, file, suffix, gpu-capable)
101
+ try:
102
+ w = weights if f == '-' else \
103
+ export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights
104
+ assert suffix in str(w), 'export failed'
105
+ y.append([name, True])
106
+ except Exception:
107
+ y.append([name, False]) # mAP, t_inference
108
+
109
+ # Print results
110
+ LOGGER.info('\n')
111
+ parse_opt()
112
+ notebook_init() # print system info
113
+ py = pd.DataFrame(y, columns=['Format', 'Export'])
114
+ LOGGER.info(f'\nExports complete ({time.time() - t:.2f}s)')
115
+ LOGGER.info(str(py))
116
+ return py
117
+
118
+
119
+ def parse_opt():
120
+ parser = argparse.ArgumentParser()
121
+ parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='weights path')
122
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
123
+ parser.add_argument('--batch-size', type=int, default=1, help='batch size')
124
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
125
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
126
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
127
+ parser.add_argument('--test', action='store_true', help='test exports only')
128
+ parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
129
+ parser.add_argument('--hard-fail', nargs='?', const=True, default=False, help='Exception on error or < min metric')
130
+ opt = parser.parse_args()
131
+ opt.data = check_yaml(opt.data) # check YAML
132
+ print_args(vars(opt))
133
+ return opt
134
+
135
+
136
+ def main(opt):
137
+ test(**vars(opt)) if opt.test else run(**vars(opt))
138
+
139
+
140
+ if __name__ == "__main__":
141
+ opt = parse_opt()
142
+ main(opt)
yolov9/classify/predict.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
+ """
3
+ Run YOLOv5 classification inference on images, videos, directories, globs, YouTube, webcam, streams, etc.
4
+
5
+ Usage - sources:
6
+ $ python classify/predict.py --weights yolov5s-cls.pt --source 0 # webcam
7
+ img.jpg # image
8
+ vid.mp4 # video
9
+ screen # screenshot
10
+ path/ # directory
11
+ 'path/*.jpg' # glob
12
+ 'https://youtu.be/Zgi9g1ksQHc' # YouTube
13
+ 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP stream
14
+
15
+ Usage - formats:
16
+ $ python classify/predict.py --weights yolov5s-cls.pt # PyTorch
17
+ yolov5s-cls.torchscript # TorchScript
18
+ yolov5s-cls.onnx # ONNX Runtime or OpenCV DNN with --dnn
19
+ yolov5s-cls_openvino_model # OpenVINO
20
+ yolov5s-cls.engine # TensorRT
21
+ yolov5s-cls.mlmodel # CoreML (macOS-only)
22
+ yolov5s-cls_saved_model # TensorFlow SavedModel
23
+ yolov5s-cls.pb # TensorFlow GraphDef
24
+ yolov5s-cls.tflite # TensorFlow Lite
25
+ yolov5s-cls_edgetpu.tflite # TensorFlow Edge TPU
26
+ yolov5s-cls_paddle_model # PaddlePaddle
27
+ """
28
+
29
+ import argparse
30
+ import os
31
+ import platform
32
+ import sys
33
+ from pathlib import Path
34
+
35
+ import torch
36
+ import torch.nn.functional as F
37
+
38
+ FILE = Path(__file__).resolve()
39
+ ROOT = FILE.parents[1] # YOLOv5 root directory
40
+ if str(ROOT) not in sys.path:
41
+ sys.path.append(str(ROOT)) # add ROOT to PATH
42
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
43
+
44
+ from models.common import DetectMultiBackend
45
+ from utils.augmentations import classify_transforms
46
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
47
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
48
+ increment_path, print_args, strip_optimizer)
49
+ from utils.plots import Annotator
50
+ from utils.torch_utils import select_device, smart_inference_mode
51
+
52
+
53
+ @smart_inference_mode()
54
+ def run(
55
+ weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s)
56
+ source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
57
+ data=ROOT / 'data/coco128.yaml', # dataset.yaml path
58
+ imgsz=(224, 224), # inference size (height, width)
59
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
60
+ view_img=False, # show results
61
+ save_txt=False, # save results to *.txt
62
+ nosave=False, # do not save images/videos
63
+ augment=False, # augmented inference
64
+ visualize=False, # visualize features
65
+ update=False, # update all models
66
+ project=ROOT / 'runs/predict-cls', # save results to project/name
67
+ name='exp', # save results to project/name
68
+ exist_ok=False, # existing project/name ok, do not increment
69
+ half=False, # use FP16 half-precision inference
70
+ dnn=False, # use OpenCV DNN for ONNX inference
71
+ vid_stride=1, # video frame-rate stride
72
+ ):
73
+ source = str(source)
74
+ save_img = not nosave and not source.endswith('.txt') # save inference images
75
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
76
+ is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
77
+ webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
78
+ screenshot = source.lower().startswith('screen')
79
+ if is_url and is_file:
80
+ source = check_file(source) # download
81
+
82
+ # Directories
83
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
84
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
85
+
86
+ # Load model
87
+ device = select_device(device)
88
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
89
+ stride, names, pt = model.stride, model.names, model.pt
90
+ imgsz = check_img_size(imgsz, s=stride) # check image size
91
+
92
+ # Dataloader
93
+ bs = 1 # batch_size
94
+ if webcam:
95
+ view_img = check_imshow(warn=True)
96
+ dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
97
+ bs = len(dataset)
98
+ elif screenshot:
99
+ dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
100
+ else:
101
+ dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
102
+ vid_path, vid_writer = [None] * bs, [None] * bs
103
+
104
+ # Run inference
105
+ model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
106
+ seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
107
+ for path, im, im0s, vid_cap, s in dataset:
108
+ with dt[0]:
109
+ im = torch.Tensor(im).to(model.device)
110
+ im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
111
+ if len(im.shape) == 3:
112
+ im = im[None] # expand for batch dim
113
+
114
+ # Inference
115
+ with dt[1]:
116
+ results = model(im)
117
+
118
+ # Post-process
119
+ with dt[2]:
120
+ pred = F.softmax(results, dim=1) # probabilities
121
+
122
+ # Process predictions
123
+ for i, prob in enumerate(pred): # per image
124
+ seen += 1
125
+ if webcam: # batch_size >= 1
126
+ p, im0, frame = path[i], im0s[i].copy(), dataset.count
127
+ s += f'{i}: '
128
+ else:
129
+ p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
130
+
131
+ p = Path(p) # to Path
132
+ save_path = str(save_dir / p.name) # im.jpg
133
+ txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
134
+
135
+ s += '%gx%g ' % im.shape[2:] # print string
136
+ annotator = Annotator(im0, example=str(names), pil=True)
137
+
138
+ # Print results
139
+ top5i = prob.argsort(0, descending=True)[:5].tolist() # top 5 indices
140
+ s += f"{', '.join(f'{names[j]} {prob[j]:.2f}' for j in top5i)}, "
141
+
142
+ # Write results
143
+ text = '\n'.join(f'{prob[j]:.2f} {names[j]}' for j in top5i)
144
+ if save_img or view_img: # Add bbox to image
145
+ annotator.text((32, 32), text, txt_color=(255, 255, 255))
146
+ if save_txt: # Write to file
147
+ with open(f'{txt_path}.txt', 'a') as f:
148
+ f.write(text + '\n')
149
+
150
+ # Stream results
151
+ im0 = annotator.result()
152
+ if view_img:
153
+ if platform.system() == 'Linux' and p not in windows:
154
+ windows.append(p)
155
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
156
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
157
+ cv2.imshow(str(p), im0)
158
+ cv2.waitKey(1) # 1 millisecond
159
+
160
+ # Save results (image with detections)
161
+ if save_img:
162
+ if dataset.mode == 'image':
163
+ cv2.imwrite(save_path, im0)
164
+ else: # 'video' or 'stream'
165
+ if vid_path[i] != save_path: # new video
166
+ vid_path[i] = save_path
167
+ if isinstance(vid_writer[i], cv2.VideoWriter):
168
+ vid_writer[i].release() # release previous video writer
169
+ if vid_cap: # video
170
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
171
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
172
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
173
+ else: # stream
174
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
175
+ save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
176
+ vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
177
+ vid_writer[i].write(im0)
178
+
179
+ # Print time (inference-only)
180
+ LOGGER.info(f"{s}{dt[1].dt * 1E3:.1f}ms")
181
+
182
+ # Print results
183
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
184
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
185
+ if save_txt or save_img:
186
+ s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
187
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
188
+ if update:
189
+ strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
190
+
191
+
192
+ def parse_opt():
193
+ parser = argparse.ArgumentParser()
194
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model path(s)')
195
+ parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
196
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
197
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[224], help='inference size h,w')
198
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
199
+ parser.add_argument('--view-img', action='store_true', help='show results')
200
+ parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
201
+ parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
202
+ parser.add_argument('--augment', action='store_true', help='augmented inference')
203
+ parser.add_argument('--visualize', action='store_true', help='visualize features')
204
+ parser.add_argument('--update', action='store_true', help='update all models')
205
+ parser.add_argument('--project', default=ROOT / 'runs/predict-cls', help='save results to project/name')
206
+ parser.add_argument('--name', default='exp', help='save results to project/name')
207
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
208
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
209
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
210
+ parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
211
+ opt = parser.parse_args()
212
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
213
+ print_args(vars(opt))
214
+ return opt
215
+
216
+
217
+ def main(opt):
218
+ check_requirements(exclude=('tensorboard', 'thop'))
219
+ run(**vars(opt))
220
+
221
+
222
+ if __name__ == "__main__":
223
+ opt = parse_opt()
224
+ main(opt)
yolov9/classify/train.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
+ """
3
+ Train a YOLOv5 classifier model on a classification dataset
4
+
5
+ Usage - Single-GPU training:
6
+ $ python classify/train.py --model yolov5s-cls.pt --data imagenette160 --epochs 5 --img 224
7
+
8
+ Usage - Multi-GPU DDP training:
9
+ $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3
10
+
11
+ Datasets: --data mnist, fashion-mnist, cifar10, cifar100, imagenette, imagewoof, imagenet, or 'path/to/data'
12
+ YOLOv5-cls models: --model yolov5n-cls.pt, yolov5s-cls.pt, yolov5m-cls.pt, yolov5l-cls.pt, yolov5x-cls.pt
13
+ Torchvision models: --model resnet50, efficientnet_b0, etc. See https://pytorch.org/vision/stable/models.html
14
+ """
15
+
16
+ import argparse
17
+ import os
18
+ import subprocess
19
+ import sys
20
+ import time
21
+ from copy import deepcopy
22
+ from datetime import datetime
23
+ from pathlib import Path
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ import torch.hub as hub
28
+ import torch.optim.lr_scheduler as lr_scheduler
29
+ import torchvision
30
+ from torch.cuda import amp
31
+ from tqdm import tqdm
32
+
33
+ FILE = Path(__file__).resolve()
34
+ ROOT = FILE.parents[1] # YOLOv5 root directory
35
+ if str(ROOT) not in sys.path:
36
+ sys.path.append(str(ROOT)) # add ROOT to PATH
37
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
38
+
39
+ from classify import val as validate
40
+ from models.experimental import attempt_load
41
+ from models.yolo import ClassificationModel, DetectionModel
42
+ from utils.dataloaders import create_classification_dataloader
43
+ from utils.general import (DATASETS_DIR, LOGGER, TQDM_BAR_FORMAT, WorkingDirectory, check_git_info, check_git_status,
44
+ check_requirements, colorstr, download, increment_path, init_seeds, print_args, yaml_save)
45
+ from utils.loggers import GenericLogger
46
+ from utils.plots import imshow_cls
47
+ from utils.torch_utils import (ModelEMA, model_info, reshape_classifier_output, select_device, smart_DDP,
48
+ smart_optimizer, smartCrossEntropyLoss, torch_distributed_zero_first)
49
+
50
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
51
+ RANK = int(os.getenv('RANK', -1))
52
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
53
+ GIT_INFO = check_git_info()
54
+
55
+
56
+ def train(opt, device):
57
+ init_seeds(opt.seed + 1 + RANK, deterministic=True)
58
+ save_dir, data, bs, epochs, nw, imgsz, pretrained = \
59
+ opt.save_dir, Path(opt.data), opt.batch_size, opt.epochs, min(os.cpu_count() - 1, opt.workers), \
60
+ opt.imgsz, str(opt.pretrained).lower() == 'true'
61
+ cuda = device.type != 'cpu'
62
+
63
+ # Directories
64
+ wdir = save_dir / 'weights'
65
+ wdir.mkdir(parents=True, exist_ok=True) # make dir
66
+ last, best = wdir / 'last.pt', wdir / 'best.pt'
67
+
68
+ # Save run settings
69
+ yaml_save(save_dir / 'opt.yaml', vars(opt))
70
+
71
+ # Logger
72
+ logger = GenericLogger(opt=opt, console_logger=LOGGER) if RANK in {-1, 0} else None
73
+
74
+ # Download Dataset
75
+ with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
76
+ data_dir = data if data.is_dir() else (DATASETS_DIR / data)
77
+ if not data_dir.is_dir():
78
+ LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
79
+ t = time.time()
80
+ if str(data) == 'imagenet':
81
+ subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
82
+ else:
83
+ url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{data}.zip'
84
+ download(url, dir=data_dir.parent)
85
+ s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
86
+ LOGGER.info(s)
87
+
88
+ # Dataloaders
89
+ nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes
90
+ trainloader = create_classification_dataloader(path=data_dir / 'train',
91
+ imgsz=imgsz,
92
+ batch_size=bs // WORLD_SIZE,
93
+ augment=True,
94
+ cache=opt.cache,
95
+ rank=LOCAL_RANK,
96
+ workers=nw)
97
+
98
+ test_dir = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
99
+ if RANK in {-1, 0}:
100
+ testloader = create_classification_dataloader(path=test_dir,
101
+ imgsz=imgsz,
102
+ batch_size=bs // WORLD_SIZE * 2,
103
+ augment=False,
104
+ cache=opt.cache,
105
+ rank=-1,
106
+ workers=nw)
107
+
108
+ # Model
109
+ with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
110
+ if Path(opt.model).is_file() or opt.model.endswith('.pt'):
111
+ model = attempt_load(opt.model, device='cpu', fuse=False)
112
+ elif opt.model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
113
+ model = torchvision.models.__dict__[opt.model](weights='IMAGENET1K_V1' if pretrained else None)
114
+ else:
115
+ m = hub.list('ultralytics/yolov5') # + hub.list('pytorch/vision') # models
116
+ raise ModuleNotFoundError(f'--model {opt.model} not found. Available models are: \n' + '\n'.join(m))
117
+ if isinstance(model, DetectionModel):
118
+ LOGGER.warning("WARNING ⚠️ pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'")
119
+ model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10) # convert to classification model
120
+ reshape_classifier_output(model, nc) # update class count
121
+ for m in model.modules():
122
+ if not pretrained and hasattr(m, 'reset_parameters'):
123
+ m.reset_parameters()
124
+ if isinstance(m, torch.nn.Dropout) and opt.dropout is not None:
125
+ m.p = opt.dropout # set dropout
126
+ for p in model.parameters():
127
+ p.requires_grad = True # for training
128
+ model = model.to(device)
129
+
130
+ # Info
131
+ if RANK in {-1, 0}:
132
+ model.names = trainloader.dataset.classes # attach class names
133
+ model.transforms = testloader.dataset.torch_transforms # attach inference transforms
134
+ model_info(model)
135
+ if opt.verbose:
136
+ LOGGER.info(model)
137
+ images, labels = next(iter(trainloader))
138
+ file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
139
+ logger.log_images(file, name='Train Examples')
140
+ logger.log_graph(model, imgsz) # log model
141
+
142
+ # Optimizer
143
+ optimizer = smart_optimizer(model, opt.optimizer, opt.lr0, momentum=0.9, decay=opt.decay)
144
+
145
+ # Scheduler
146
+ lrf = 0.01 # final lr (fraction of lr0)
147
+ # lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf # cosine
148
+ lf = lambda x: (1 - x / epochs) * (1 - lrf) + lrf # linear
149
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
150
+ # scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr0, total_steps=epochs, pct_start=0.1,
151
+ # final_div_factor=1 / 25 / lrf)
152
+
153
+ # EMA
154
+ ema = ModelEMA(model) if RANK in {-1, 0} else None
155
+
156
+ # DDP mode
157
+ if cuda and RANK != -1:
158
+ model = smart_DDP(model)
159
+
160
+ # Train
161
+ t0 = time.time()
162
+ criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing) # loss function
163
+ best_fitness = 0.0
164
+ scaler = amp.GradScaler(enabled=cuda)
165
+ val = test_dir.stem # 'val' or 'test'
166
+ LOGGER.info(f'Image sizes {imgsz} train, {imgsz} test\n'
167
+ f'Using {nw * WORLD_SIZE} dataloader workers\n'
168
+ f"Logging results to {colorstr('bold', save_dir)}\n"
169
+ f'Starting {opt.model} training on {data} dataset with {nc} classes for {epochs} epochs...\n\n'
170
+ f"{'Epoch':>10}{'GPU_mem':>10}{'train_loss':>12}{f'{val}_loss':>12}{'top1_acc':>12}{'top5_acc':>12}")
171
+ for epoch in range(epochs): # loop over the dataset multiple times
172
+ tloss, vloss, fitness = 0.0, 0.0, 0.0 # train loss, val loss, fitness
173
+ model.train()
174
+ if RANK != -1:
175
+ trainloader.sampler.set_epoch(epoch)
176
+ pbar = enumerate(trainloader)
177
+ if RANK in {-1, 0}:
178
+ pbar = tqdm(enumerate(trainloader), total=len(trainloader), bar_format=TQDM_BAR_FORMAT)
179
+ for i, (images, labels) in pbar: # progress bar
180
+ images, labels = images.to(device, non_blocking=True), labels.to(device)
181
+
182
+ # Forward
183
+ with amp.autocast(enabled=cuda): # stability issues when enabled
184
+ loss = criterion(model(images), labels)
185
+
186
+ # Backward
187
+ scaler.scale(loss).backward()
188
+
189
+ # Optimize
190
+ scaler.unscale_(optimizer) # unscale gradients
191
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients
192
+ scaler.step(optimizer)
193
+ scaler.update()
194
+ optimizer.zero_grad()
195
+ if ema:
196
+ ema.update(model)
197
+
198
+ if RANK in {-1, 0}:
199
+ # Print
200
+ tloss = (tloss * i + loss.item()) / (i + 1) # update mean losses
201
+ mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
202
+ pbar.desc = f"{f'{epoch + 1}/{epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
203
+
204
+ # Test
205
+ if i == len(pbar) - 1: # last batch
206
+ top1, top5, vloss = validate.run(model=ema.ema,
207
+ dataloader=testloader,
208
+ criterion=criterion,
209
+ pbar=pbar) # test accuracy, loss
210
+ fitness = top1 # define fitness as top1 accuracy
211
+
212
+ # Scheduler
213
+ scheduler.step()
214
+
215
+ # Log metrics
216
+ if RANK in {-1, 0}:
217
+ # Best fitness
218
+ if fitness > best_fitness:
219
+ best_fitness = fitness
220
+
221
+ # Log
222
+ metrics = {
223
+ "train/loss": tloss,
224
+ f"{val}/loss": vloss,
225
+ "metrics/accuracy_top1": top1,
226
+ "metrics/accuracy_top5": top5,
227
+ "lr/0": optimizer.param_groups[0]['lr']} # learning rate
228
+ logger.log_metrics(metrics, epoch)
229
+
230
+ # Save model
231
+ final_epoch = epoch + 1 == epochs
232
+ if (not opt.nosave) or final_epoch:
233
+ ckpt = {
234
+ 'epoch': epoch,
235
+ 'best_fitness': best_fitness,
236
+ 'model': deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(),
237
+ 'ema': None, # deepcopy(ema.ema).half(),
238
+ 'updates': ema.updates,
239
+ 'optimizer': None, # optimizer.state_dict(),
240
+ 'opt': vars(opt),
241
+ 'git': GIT_INFO, # {remote, branch, commit} if a git repo
242
+ 'date': datetime.now().isoformat()}
243
+
244
+ # Save last, best and delete
245
+ torch.save(ckpt, last)
246
+ if best_fitness == fitness:
247
+ torch.save(ckpt, best)
248
+ del ckpt
249
+
250
+ # Train complete
251
+ if RANK in {-1, 0} and final_epoch:
252
+ LOGGER.info(f'\nTraining complete ({(time.time() - t0) / 3600:.3f} hours)'
253
+ f"\nResults saved to {colorstr('bold', save_dir)}"
254
+ f"\nPredict: python classify/predict.py --weights {best} --source im.jpg"
255
+ f"\nValidate: python classify/val.py --weights {best} --data {data_dir}"
256
+ f"\nExport: python export.py --weights {best} --include onnx"
257
+ f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{best}')"
258
+ f"\nVisualize: https://netron.app\n")
259
+
260
+ # Plot examples
261
+ images, labels = (x[:25] for x in next(iter(testloader))) # first 25 images and labels
262
+ pred = torch.max(ema.ema(images.to(device)), 1)[1]
263
+ file = imshow_cls(images, labels, pred, model.names, verbose=False, f=save_dir / 'test_images.jpg')
264
+
265
+ # Log results
266
+ meta = {"epochs": epochs, "top1_acc": best_fitness, "date": datetime.now().isoformat()}
267
+ logger.log_images(file, name='Test Examples (true-predicted)', epoch=epoch)
268
+ logger.log_model(best, epochs, metadata=meta)
269
+
270
+
271
+ def parse_opt(known=False):
272
+ parser = argparse.ArgumentParser()
273
+ parser.add_argument('--model', type=str, default='yolov5s-cls.pt', help='initial weights path')
274
+ parser.add_argument('--data', type=str, default='imagenette160', help='cifar10, cifar100, mnist, imagenet, ...')
275
+ parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
276
+ parser.add_argument('--batch-size', type=int, default=64, help='total batch size for all GPUs')
277
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)')
278
+ parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
279
+ parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
280
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
281
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
282
+ parser.add_argument('--project', default=ROOT / 'runs/train-cls', help='save to project/name')
283
+ parser.add_argument('--name', default='exp', help='save to project/name')
284
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
285
+ parser.add_argument('--pretrained', nargs='?', const=True, default=True, help='start from i.e. --pretrained False')
286
+ parser.add_argument('--optimizer', choices=['SGD', 'Adam', 'AdamW', 'RMSProp'], default='Adam', help='optimizer')
287
+ parser.add_argument('--lr0', type=float, default=0.001, help='initial learning rate')
288
+ parser.add_argument('--decay', type=float, default=5e-5, help='weight decay')
289
+ parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing epsilon')
290
+ parser.add_argument('--cutoff', type=int, default=None, help='Model layer cutoff index for Classify() head')
291
+ parser.add_argument('--dropout', type=float, default=None, help='Dropout (fraction)')
292
+ parser.add_argument('--verbose', action='store_true', help='Verbose mode')
293
+ parser.add_argument('--seed', type=int, default=0, help='Global training seed')
294
+ parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
295
+ return parser.parse_known_args()[0] if known else parser.parse_args()
296
+
297
+
298
+ def main(opt):
299
+ # Checks
300
+ if RANK in {-1, 0}:
301
+ print_args(vars(opt))
302
+ check_git_status()
303
+ check_requirements()
304
+
305
+ # DDP mode
306
+ device = select_device(opt.device, batch_size=opt.batch_size)
307
+ if LOCAL_RANK != -1:
308
+ assert opt.batch_size != -1, 'AutoBatch is coming soon for classification, please pass a valid --batch-size'
309
+ assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
310
+ assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
311
+ torch.cuda.set_device(LOCAL_RANK)
312
+ device = torch.device('cuda', LOCAL_RANK)
313
+ dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
314
+
315
+ # Parameters
316
+ opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok) # increment run
317
+
318
+ # Train
319
+ train(opt, device)
320
+
321
+
322
+ def run(**kwargs):
323
+ # Usage: from yolov5 import classify; classify.train.run(data=mnist, imgsz=320, model='yolov5m')
324
+ opt = parse_opt(True)
325
+ for k, v in kwargs.items():
326
+ setattr(opt, k, v)
327
+ main(opt)
328
+ return opt
329
+
330
+
331
+ if __name__ == "__main__":
332
+ opt = parse_opt()
333
+ main(opt)
yolov9/classify/val.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
+ """
3
+ Validate a trained YOLOv5 classification model on a classification dataset
4
+
5
+ Usage:
6
+ $ bash data/scripts/get_imagenet.sh --val # download ImageNet val split (6.3G, 50000 images)
7
+ $ python classify/val.py --weights yolov5m-cls.pt --data ../datasets/imagenet --img 224 # validate ImageNet
8
+
9
+ Usage - formats:
10
+ $ python classify/val.py --weights yolov5s-cls.pt # PyTorch
11
+ yolov5s-cls.torchscript # TorchScript
12
+ yolov5s-cls.onnx # ONNX Runtime or OpenCV DNN with --dnn
13
+ yolov5s-cls_openvino_model # OpenVINO
14
+ yolov5s-cls.engine # TensorRT
15
+ yolov5s-cls.mlmodel # CoreML (macOS-only)
16
+ yolov5s-cls_saved_model # TensorFlow SavedModel
17
+ yolov5s-cls.pb # TensorFlow GraphDef
18
+ yolov5s-cls.tflite # TensorFlow Lite
19
+ yolov5s-cls_edgetpu.tflite # TensorFlow Edge TPU
20
+ yolov5s-cls_paddle_model # PaddlePaddle
21
+ """
22
+
23
+ import argparse
24
+ import os
25
+ import sys
26
+ from pathlib import Path
27
+
28
+ import torch
29
+ from tqdm import tqdm
30
+
31
+ FILE = Path(__file__).resolve()
32
+ ROOT = FILE.parents[1] # YOLOv5 root directory
33
+ if str(ROOT) not in sys.path:
34
+ sys.path.append(str(ROOT)) # add ROOT to PATH
35
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
36
+
37
+ from models.common import DetectMultiBackend
38
+ from utils.dataloaders import create_classification_dataloader
39
+ from utils.general import (LOGGER, TQDM_BAR_FORMAT, Profile, check_img_size, check_requirements, colorstr,
40
+ increment_path, print_args)
41
+ from utils.torch_utils import select_device, smart_inference_mode
42
+
43
+
44
+ @smart_inference_mode()
45
+ def run(
46
+ data=ROOT / '../datasets/mnist', # dataset dir
47
+ weights=ROOT / 'yolov5s-cls.pt', # model.pt path(s)
48
+ batch_size=128, # batch size
49
+ imgsz=224, # inference size (pixels)
50
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
51
+ workers=8, # max dataloader workers (per RANK in DDP mode)
52
+ verbose=False, # verbose output
53
+ project=ROOT / 'runs/val-cls', # save to project/name
54
+ name='exp', # save to project/name
55
+ exist_ok=False, # existing project/name ok, do not increment
56
+ half=False, # use FP16 half-precision inference
57
+ dnn=False, # use OpenCV DNN for ONNX inference
58
+ model=None,
59
+ dataloader=None,
60
+ criterion=None,
61
+ pbar=None,
62
+ ):
63
+ # Initialize/load model and set device
64
+ training = model is not None
65
+ if training: # called by train.py
66
+ device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model
67
+ half &= device.type != 'cpu' # half precision only supported on CUDA
68
+ model.half() if half else model.float()
69
+ else: # called directly
70
+ device = select_device(device, batch_size=batch_size)
71
+
72
+ # Directories
73
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
74
+ save_dir.mkdir(parents=True, exist_ok=True) # make dir
75
+
76
+ # Load model
77
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, fp16=half)
78
+ stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
79
+ imgsz = check_img_size(imgsz, s=stride) # check image size
80
+ half = model.fp16 # FP16 supported on limited backends with CUDA
81
+ if engine:
82
+ batch_size = model.batch_size
83
+ else:
84
+ device = model.device
85
+ if not (pt or jit):
86
+ batch_size = 1 # export.py models default to batch-size 1
87
+ LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
88
+
89
+ # Dataloader
90
+ data = Path(data)
91
+ test_dir = data / 'test' if (data / 'test').exists() else data / 'val' # data/test or data/val
92
+ dataloader = create_classification_dataloader(path=test_dir,
93
+ imgsz=imgsz,
94
+ batch_size=batch_size,
95
+ augment=False,
96
+ rank=-1,
97
+ workers=workers)
98
+
99
+ model.eval()
100
+ pred, targets, loss, dt = [], [], 0, (Profile(), Profile(), Profile())
101
+ n = len(dataloader) # number of batches
102
+ action = 'validating' if dataloader.dataset.root.stem == 'val' else 'testing'
103
+ desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}"
104
+ bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0)
105
+ with torch.cuda.amp.autocast(enabled=device.type != 'cpu'):
106
+ for images, labels in bar:
107
+ with dt[0]:
108
+ images, labels = images.to(device, non_blocking=True), labels.to(device)
109
+
110
+ with dt[1]:
111
+ y = model(images)
112
+
113
+ with dt[2]:
114
+ pred.append(y.argsort(1, descending=True)[:, :5])
115
+ targets.append(labels)
116
+ if criterion:
117
+ loss += criterion(y, labels)
118
+
119
+ loss /= n
120
+ pred, targets = torch.cat(pred), torch.cat(targets)
121
+ correct = (targets[:, None] == pred).float()
122
+ acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
123
+ top1, top5 = acc.mean(0).tolist()
124
+
125
+ if pbar:
126
+ pbar.desc = f"{pbar.desc[:-36]}{loss:>12.3g}{top1:>12.3g}{top5:>12.3g}"
127
+ if verbose: # all classes
128
+ LOGGER.info(f"{'Class':>24}{'Images':>12}{'top1_acc':>12}{'top5_acc':>12}")
129
+ LOGGER.info(f"{'all':>24}{targets.shape[0]:>12}{top1:>12.3g}{top5:>12.3g}")
130
+ for i, c in model.names.items():
131
+ aci = acc[targets == i]
132
+ top1i, top5i = aci.mean(0).tolist()
133
+ LOGGER.info(f"{c:>24}{aci.shape[0]:>12}{top1i:>12.3g}{top5i:>12.3g}")
134
+
135
+ # Print results
136
+ t = tuple(x.t / len(dataloader.dataset.samples) * 1E3 for x in dt) # speeds per image
137
+ shape = (1, 3, imgsz, imgsz)
138
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms post-process per image at shape {shape}' % t)
139
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
140
+
141
+ return top1, top5, loss
142
+
143
+
144
+ def parse_opt():
145
+ parser = argparse.ArgumentParser()
146
+ parser.add_argument('--data', type=str, default=ROOT / '../datasets/mnist', help='dataset path')
147
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s-cls.pt', help='model.pt path(s)')
148
+ parser.add_argument('--batch-size', type=int, default=128, help='batch size')
149
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='inference size (pixels)')
150
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
151
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
152
+ parser.add_argument('--verbose', nargs='?', const=True, default=True, help='verbose output')
153
+ parser.add_argument('--project', default=ROOT / 'runs/val-cls', help='save to project/name')
154
+ parser.add_argument('--name', default='exp', help='save to project/name')
155
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
156
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
157
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
158
+ opt = parser.parse_args()
159
+ print_args(vars(opt))
160
+ return opt
161
+
162
+
163
+ def main(opt):
164
+ check_requirements(exclude=('tensorboard', 'thop'))
165
+ run(**vars(opt))
166
+
167
+
168
+ if __name__ == "__main__":
169
+ opt = parse_opt()
170
+ main(opt)
yolov9/detect.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ FILE = Path(__file__).resolve()
10
+ ROOT = FILE.parents[0] # YOLO root directory
11
+ if str(ROOT) not in sys.path:
12
+ sys.path.append(str(ROOT)) # add ROOT to PATH
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import DetectMultiBackend
16
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
17
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
18
+ increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
19
+ from utils.plots import Annotator, colors, save_one_box
20
+ from utils.torch_utils import select_device, smart_inference_mode
21
+
22
+
23
+ @smart_inference_mode()
24
+ def run(
25
+ weights=ROOT / 'yolo.pt', # model path or triton URL
26
+ source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
27
+ data=ROOT / 'data/coco.yaml', # dataset.yaml path
28
+ imgsz=(640, 640), # inference size (height, width)
29
+ conf_thres=0.25, # confidence threshold
30
+ iou_thres=0.45, # NMS IOU threshold
31
+ max_det=1000, # maximum detections per image
32
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
33
+ view_img=False, # show results
34
+ save_txt=False, # save results to *.txt
35
+ save_conf=False, # save confidences in --save-txt labels
36
+ save_crop=False, # save cropped prediction boxes
37
+ nosave=False, # do not save images/videos
38
+ classes=None, # filter by class: --class 0, or --class 0 2 3
39
+ agnostic_nms=False, # class-agnostic NMS
40
+ augment=False, # augmented inference
41
+ visualize=False, # visualize features
42
+ update=False, # update all models
43
+ project=ROOT / 'runs/detect', # save results to project/name
44
+ name='exp', # save results to project/name
45
+ exist_ok=False, # existing project/name ok, do not increment
46
+ line_thickness=3, # bounding box thickness (pixels)
47
+ hide_labels=False, # hide labels
48
+ hide_conf=False, # hide confidences
49
+ half=False, # use FP16 half-precision inference
50
+ dnn=False, # use OpenCV DNN for ONNX inference
51
+ vid_stride=1, # video frame-rate stride
52
+ ):
53
+ source = str(source)
54
+ save_img = not nosave and not source.endswith('.txt') # save inference images
55
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
56
+ is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
57
+ webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
58
+ screenshot = source.lower().startswith('screen')
59
+ if is_url and is_file:
60
+ source = check_file(source) # download
61
+
62
+ # Directories
63
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
64
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
65
+
66
+ # Load model
67
+ device = select_device(device)
68
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
69
+ stride, names, pt = model.stride, model.names, model.pt
70
+ imgsz = check_img_size(imgsz, s=stride) # check image size
71
+
72
+ # Dataloader
73
+ bs = 1 # batch_size
74
+ if webcam:
75
+ view_img = check_imshow(warn=True)
76
+ dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
77
+ bs = len(dataset)
78
+ elif screenshot:
79
+ dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
80
+ else:
81
+ dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
82
+ vid_path, vid_writer = [None] * bs, [None] * bs
83
+
84
+ # Run inference
85
+ model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
86
+ seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
87
+ for path, im, im0s, vid_cap, s in dataset:
88
+ with dt[0]:
89
+ im = torch.from_numpy(im).to(model.device)
90
+ im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
91
+ im /= 255 # 0 - 255 to 0.0 - 1.0
92
+ if len(im.shape) == 3:
93
+ im = im[None] # expand for batch dim
94
+
95
+ # Inference
96
+ with dt[1]:
97
+ visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
98
+ pred = model(im, augment=augment, visualize=visualize)
99
+
100
+ # NMS
101
+ with dt[2]:
102
+ pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
103
+
104
+ # Second-stage classifier (optional)
105
+ # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
106
+
107
+ # Process predictions
108
+ for i, det in enumerate(pred): # per image
109
+ seen += 1
110
+ if webcam: # batch_size >= 1
111
+ p, im0, frame = path[i], im0s[i].copy(), dataset.count
112
+ s += f'{i}: '
113
+ else:
114
+ p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
115
+
116
+ p = Path(p) # to Path
117
+ save_path = str(save_dir / p.name) # im.jpg
118
+ txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
119
+ s += '%gx%g ' % im.shape[2:] # print string
120
+ gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
121
+ imc = im0.copy() if save_crop else im0 # for save_crop
122
+ annotator = Annotator(im0, line_width=line_thickness, example=str(names))
123
+ if len(det):
124
+ # Rescale boxes from img_size to im0 size
125
+ det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
126
+
127
+ # Print results
128
+ for c in det[:, 5].unique():
129
+ n = (det[:, 5] == c).sum() # detections per class
130
+ s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
131
+
132
+ # Write results
133
+ for *xyxy, conf, cls in reversed(det):
134
+ if save_txt: # Write to file
135
+ xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
136
+ line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
137
+ with open(f'{txt_path}.txt', 'a') as f:
138
+ f.write(('%g ' * len(line)).rstrip() % line + '\n')
139
+
140
+ if save_img or save_crop or view_img: # Add bbox to image
141
+ c = int(cls) # integer class
142
+ label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
143
+ annotator.box_label(xyxy, label, color=colors(c, True))
144
+ if save_crop:
145
+ save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
146
+
147
+ # Stream results
148
+ im0 = annotator.result()
149
+ if view_img:
150
+ if platform.system() == 'Linux' and p not in windows:
151
+ windows.append(p)
152
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
153
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
154
+ cv2.imshow(str(p), im0)
155
+ cv2.waitKey(1) # 1 millisecond
156
+
157
+ # Save results (image with detections)
158
+ if save_img:
159
+ if dataset.mode == 'image':
160
+ cv2.imwrite(save_path, im0)
161
+ else: # 'video' or 'stream'
162
+ if vid_path[i] != save_path: # new video
163
+ vid_path[i] = save_path
164
+ if isinstance(vid_writer[i], cv2.VideoWriter):
165
+ vid_writer[i].release() # release previous video writer
166
+ if vid_cap: # video
167
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
168
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
169
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
170
+ else: # stream
171
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
172
+ save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
173
+ vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
174
+ vid_writer[i].write(im0)
175
+
176
+ # Print time (inference-only)
177
+ LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
178
+
179
+ # Print results
180
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
181
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
182
+ if save_txt or save_img:
183
+ s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
184
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
185
+ if update:
186
+ strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
187
+
188
+
189
+ def parse_opt():
190
+ parser = argparse.ArgumentParser()
191
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model path or triton URL')
192
+ parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
193
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
194
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
195
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
196
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
197
+ parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
198
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
199
+ parser.add_argument('--view-img', action='store_true', help='show results')
200
+ parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
201
+ parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
202
+ parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
203
+ parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
204
+ parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
205
+ parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
206
+ parser.add_argument('--augment', action='store_true', help='augmented inference')
207
+ parser.add_argument('--visualize', action='store_true', help='visualize features')
208
+ parser.add_argument('--update', action='store_true', help='update all models')
209
+ parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
210
+ parser.add_argument('--name', default='exp', help='save results to project/name')
211
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
212
+ parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
213
+ parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
214
+ parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
215
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
216
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
217
+ parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
218
+ opt = parser.parse_args()
219
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
220
+ print_args(vars(opt))
221
+ return opt
222
+
223
+
224
+ def main(opt):
225
+ # check_requirements(exclude=('tensorboard', 'thop'))
226
+ run(**vars(opt))
227
+
228
+
229
+ if __name__ == "__main__":
230
+ opt = parse_opt()
231
+ main(opt)
yolov9/detect_dual.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ FILE = Path(__file__).resolve()
10
+ ROOT = FILE.parents[0] # YOLO root directory
11
+ if str(ROOT) not in sys.path:
12
+ sys.path.append(str(ROOT)) # add ROOT to PATH
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import DetectMultiBackend
16
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
17
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
18
+ increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
19
+ from utils.plots import Annotator, colors, save_one_box
20
+ from utils.torch_utils import select_device, smart_inference_mode
21
+
22
+
23
+ @smart_inference_mode()
24
+ def run(
25
+ weights=ROOT / 'yolo.pt', # model path or triton URL
26
+ source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
27
+ data=ROOT / 'data/coco.yaml', # dataset.yaml path
28
+ imgsz=(640, 640), # inference size (height, width)
29
+ conf_thres=0.25, # confidence threshold
30
+ iou_thres=0.45, # NMS IOU threshold
31
+ max_det=1000, # maximum detections per image
32
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
33
+ view_img=False, # show results
34
+ save_txt=False, # save results to *.txt
35
+ save_conf=False, # save confidences in --save-txt labels
36
+ save_crop=False, # save cropped prediction boxes
37
+ nosave=False, # do not save images/videos
38
+ classes=None, # filter by class: --class 0, or --class 0 2 3
39
+ agnostic_nms=False, # class-agnostic NMS
40
+ augment=False, # augmented inference
41
+ visualize=False, # visualize features
42
+ update=False, # update all models
43
+ project=ROOT / 'runs/detect', # save results to project/name
44
+ name='exp', # save results to project/name
45
+ exist_ok=False, # existing project/name ok, do not increment
46
+ line_thickness=3, # bounding box thickness (pixels)
47
+ hide_labels=False, # hide labels
48
+ hide_conf=False, # hide confidences
49
+ half=False, # use FP16 half-precision inference
50
+ dnn=False, # use OpenCV DNN for ONNX inference
51
+ vid_stride=1, # video frame-rate stride
52
+ ):
53
+ source = str(source)
54
+ save_img = not nosave and not source.endswith('.txt') # save inference images
55
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
56
+ is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
57
+ webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
58
+ screenshot = source.lower().startswith('screen')
59
+ if is_url and is_file:
60
+ source = check_file(source) # download
61
+
62
+ # Directories
63
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
64
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
65
+
66
+ # Load model
67
+ device = select_device(device)
68
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
69
+ stride, names, pt = model.stride, model.names, model.pt
70
+ imgsz = check_img_size(imgsz, s=stride) # check image size
71
+
72
+ # Dataloader
73
+ bs = 1 # batch_size
74
+ if webcam:
75
+ view_img = check_imshow(warn=True)
76
+ dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
77
+ bs = len(dataset)
78
+ elif screenshot:
79
+ dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
80
+ else:
81
+ dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
82
+ vid_path, vid_writer = [None] * bs, [None] * bs
83
+
84
+ # Run inference
85
+ model.warmup(imgsz=(1 if pt or model.triton else bs, 3, *imgsz)) # warmup
86
+ seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
87
+ for path, im, im0s, vid_cap, s in dataset:
88
+ with dt[0]:
89
+ im = torch.from_numpy(im).to(model.device)
90
+ im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
91
+ im /= 255 # 0 - 255 to 0.0 - 1.0
92
+ if len(im.shape) == 3:
93
+ im = im[None] # expand for batch dim
94
+
95
+ # Inference
96
+ with dt[1]:
97
+ visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
98
+ pred = model(im, augment=augment, visualize=visualize)
99
+ pred = pred[0][1]
100
+
101
+ # NMS
102
+ with dt[2]:
103
+ pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
104
+
105
+ # Second-stage classifier (optional)
106
+ # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
107
+
108
+ # Process predictions
109
+ for i, det in enumerate(pred): # per image
110
+ seen += 1
111
+ if webcam: # batch_size >= 1
112
+ p, im0, frame = path[i], im0s[i].copy(), dataset.count
113
+ s += f'{i}: '
114
+ else:
115
+ p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
116
+
117
+ p = Path(p) # to Path
118
+ save_path = str(save_dir / p.name) # im.jpg
119
+ txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
120
+ s += '%gx%g ' % im.shape[2:] # print string
121
+ gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
122
+ imc = im0.copy() if save_crop else im0 # for save_crop
123
+ annotator = Annotator(im0, line_width=line_thickness, example=str(names))
124
+ if len(det):
125
+ # Rescale boxes from img_size to im0 size
126
+ det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
127
+
128
+ # Print results
129
+ for c in det[:, 5].unique():
130
+ n = (det[:, 5] == c).sum() # detections per class
131
+ s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
132
+
133
+ # Write results
134
+ for *xyxy, conf, cls in reversed(det):
135
+ if save_txt: # Write to file
136
+ xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
137
+ line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
138
+ with open(f'{txt_path}.txt', 'a') as f:
139
+ f.write(('%g ' * len(line)).rstrip() % line + '\n')
140
+
141
+ if save_img or save_crop or view_img: # Add bbox to image
142
+ c = int(cls) # integer class
143
+ label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
144
+ annotator.box_label(xyxy, label, color=colors(c, True))
145
+ if save_crop:
146
+ save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
147
+
148
+ # Stream results
149
+ im0 = annotator.result()
150
+ if view_img:
151
+ if platform.system() == 'Linux' and p not in windows:
152
+ windows.append(p)
153
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
154
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
155
+ cv2.imshow(str(p), im0)
156
+ cv2.waitKey(1) # 1 millisecond
157
+
158
+ # Save results (image with detections)
159
+ if save_img:
160
+ if dataset.mode == 'image':
161
+ cv2.imwrite(save_path, im0)
162
+ else: # 'video' or 'stream'
163
+ if vid_path[i] != save_path: # new video
164
+ vid_path[i] = save_path
165
+ if isinstance(vid_writer[i], cv2.VideoWriter):
166
+ vid_writer[i].release() # release previous video writer
167
+ if vid_cap: # video
168
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
169
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
170
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
171
+ else: # stream
172
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
173
+ save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
174
+ vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
175
+ vid_writer[i].write(im0)
176
+
177
+ # Print time (inference-only)
178
+ LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
179
+
180
+ # Print results
181
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
182
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
183
+ if save_txt or save_img:
184
+ s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
185
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
186
+ if update:
187
+ strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
188
+
189
+
190
+ def parse_opt():
191
+ parser = argparse.ArgumentParser()
192
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model path or triton URL')
193
+ parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
194
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
195
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
196
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
197
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
198
+ parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
199
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
200
+ parser.add_argument('--view-img', action='store_true', help='show results')
201
+ parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
202
+ parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
203
+ parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
204
+ parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
205
+ parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
206
+ parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
207
+ parser.add_argument('--augment', action='store_true', help='augmented inference')
208
+ parser.add_argument('--visualize', action='store_true', help='visualize features')
209
+ parser.add_argument('--update', action='store_true', help='update all models')
210
+ parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
211
+ parser.add_argument('--name', default='exp', help='save results to project/name')
212
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
213
+ parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
214
+ parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
215
+ parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
216
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
217
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
218
+ parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
219
+ opt = parser.parse_args()
220
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
221
+ print_args(vars(opt))
222
+ return opt
223
+
224
+
225
+ def main(opt):
226
+ # check_requirements(exclude=('tensorboard', 'thop'))
227
+ run(**vars(opt))
228
+
229
+
230
+ if __name__ == "__main__":
231
+ opt = parse_opt()
232
+ main(opt)
yolov9/export.py ADDED
@@ -0,0 +1,686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import contextlib
3
+ import json
4
+ import os
5
+ import platform
6
+ import re
7
+ import subprocess
8
+ import sys
9
+ import time
10
+ import warnings
11
+ from pathlib import Path
12
+
13
+ import pandas as pd
14
+ import torch
15
+ from torch.utils.mobile_optimizer import optimize_for_mobile
16
+
17
+ FILE = Path(__file__).resolve()
18
+ ROOT = FILE.parents[0] # YOLO root directory
19
+ if str(ROOT) not in sys.path:
20
+ sys.path.append(str(ROOT)) # add ROOT to PATH
21
+ if platform.system() != 'Windows':
22
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
23
+
24
+ from models.experimental import attempt_load, End2End
25
+ from models.yolo import ClassificationModel, Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel
26
+ from utils.dataloaders import LoadImages
27
+ from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
28
+ check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
29
+ from utils.torch_utils import select_device, smart_inference_mode
30
+
31
+ MACOS = platform.system() == 'Darwin' # macOS environment
32
+
33
+
34
+ def export_formats():
35
+ # YOLO export formats
36
+ x = [
37
+ ['PyTorch', '-', '.pt', True, True],
38
+ ['TorchScript', 'torchscript', '.torchscript', True, True],
39
+ ['ONNX', 'onnx', '.onnx', True, True],
40
+ ['ONNX END2END', 'onnx_end2end', '_end2end.onnx', True, True],
41
+ ['OpenVINO', 'openvino', '_openvino_model', True, False],
42
+ ['TensorRT', 'engine', '.engine', False, True],
43
+ ['CoreML', 'coreml', '.mlmodel', True, False],
44
+ ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
45
+ ['TensorFlow GraphDef', 'pb', '.pb', True, True],
46
+ ['TensorFlow Lite', 'tflite', '.tflite', True, False],
47
+ ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
48
+ ['TensorFlow.js', 'tfjs', '_web_model', False, False],
49
+ ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
50
+ return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
51
+
52
+
53
+ def try_export(inner_func):
54
+ # YOLO export decorator, i..e @try_export
55
+ inner_args = get_default_args(inner_func)
56
+
57
+ def outer_func(*args, **kwargs):
58
+ prefix = inner_args['prefix']
59
+ try:
60
+ with Profile() as dt:
61
+ f, model = inner_func(*args, **kwargs)
62
+ LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
63
+ return f, model
64
+ except Exception as e:
65
+ LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
66
+ return None, None
67
+
68
+ return outer_func
69
+
70
+
71
+ @try_export
72
+ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
73
+ # YOLO TorchScript model export
74
+ LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
75
+ f = file.with_suffix('.torchscript')
76
+
77
+ ts = torch.jit.trace(model, im, strict=False)
78
+ d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
79
+ extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
80
+ if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
81
+ optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
82
+ else:
83
+ ts.save(str(f), _extra_files=extra_files)
84
+ return f, None
85
+
86
+
87
+ @try_export
88
+ def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
89
+ # YOLO ONNX export
90
+ check_requirements('onnx')
91
+ import onnx
92
+
93
+ LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
94
+ f = file.with_suffix('.onnx')
95
+
96
+ output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
97
+ if dynamic:
98
+ dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
99
+ if isinstance(model, SegmentationModel):
100
+ dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
101
+ dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
102
+ elif isinstance(model, DetectionModel):
103
+ dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
104
+
105
+ torch.onnx.export(
106
+ model.cpu() if dynamic else model, # --dynamic only compatible with cpu
107
+ im.cpu() if dynamic else im,
108
+ f,
109
+ verbose=False,
110
+ opset_version=opset,
111
+ do_constant_folding=True,
112
+ input_names=['images'],
113
+ output_names=output_names,
114
+ dynamic_axes=dynamic or None)
115
+
116
+ # Checks
117
+ model_onnx = onnx.load(f) # load onnx model
118
+ onnx.checker.check_model(model_onnx) # check onnx model
119
+
120
+ # Metadata
121
+ d = {'stride': int(max(model.stride)), 'names': model.names}
122
+ for k, v in d.items():
123
+ meta = model_onnx.metadata_props.add()
124
+ meta.key, meta.value = k, str(v)
125
+ onnx.save(model_onnx, f)
126
+
127
+ # Simplify
128
+ if simplify:
129
+ try:
130
+ cuda = torch.cuda.is_available()
131
+ check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
132
+ import onnxsim
133
+
134
+ LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
135
+ model_onnx, check = onnxsim.simplify(model_onnx)
136
+ assert check, 'assert check failed'
137
+ onnx.save(model_onnx, f)
138
+ except Exception as e:
139
+ LOGGER.info(f'{prefix} simplifier failure: {e}')
140
+ return f, model_onnx
141
+
142
+
143
+ @try_export
144
+ def export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, labels, prefix=colorstr('ONNX END2END:')):
145
+ # YOLO ONNX export
146
+ check_requirements('onnx')
147
+ import onnx
148
+ LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
149
+ f = os.path.splitext(file)[0] + "-end2end.onnx"
150
+ batch_size = 'batch'
151
+
152
+ dynamic_axes = {'images': {0 : 'batch', 2: 'height', 3:'width'}, } # variable length axes
153
+
154
+ output_axes = {
155
+ 'num_dets': {0: 'batch'},
156
+ 'det_boxes': {0: 'batch'},
157
+ 'det_scores': {0: 'batch'},
158
+ 'det_classes': {0: 'batch'},
159
+ }
160
+ dynamic_axes.update(output_axes)
161
+ model = End2End(model, topk_all, iou_thres, conf_thres, None ,device, labels)
162
+
163
+ output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
164
+ shapes = [ batch_size, 1, batch_size, topk_all, 4,
165
+ batch_size, topk_all, batch_size, topk_all]
166
+
167
+ torch.onnx.export(model,
168
+ im,
169
+ f,
170
+ verbose=False,
171
+ export_params=True, # store the trained parameter weights inside the model file
172
+ opset_version=12,
173
+ do_constant_folding=True, # whether to execute constant folding for optimization
174
+ input_names=['images'],
175
+ output_names=output_names,
176
+ dynamic_axes=dynamic_axes)
177
+
178
+ # Checks
179
+ model_onnx = onnx.load(f) # load onnx model
180
+ onnx.checker.check_model(model_onnx) # check onnx model
181
+ for i in model_onnx.graph.output:
182
+ for j in i.type.tensor_type.shape.dim:
183
+ j.dim_param = str(shapes.pop(0))
184
+
185
+ if simplify:
186
+ try:
187
+ import onnxsim
188
+
189
+ print('\nStarting to simplify ONNX...')
190
+ model_onnx, check = onnxsim.simplify(model_onnx)
191
+ assert check, 'assert check failed'
192
+ except Exception as e:
193
+ print(f'Simplifier failure: {e}')
194
+
195
+ # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
196
+ onnx.save(model_onnx,f)
197
+ print('ONNX export success, saved as %s' % f)
198
+ return f, model_onnx
199
+
200
+
201
+ @try_export
202
+ def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
203
+ # YOLO OpenVINO export
204
+ check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/
205
+ import openvino.inference_engine as ie
206
+
207
+ LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
208
+ f = str(file).replace('.pt', f'_openvino_model{os.sep}')
209
+
210
+ #cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
211
+ #cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} {"--compress_to_fp16" if half else ""}"
212
+ half_arg = "--compress_to_fp16" if half else ""
213
+ cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} {half_arg}"
214
+ subprocess.run(cmd.split(), check=True, env=os.environ) # export
215
+ yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
216
+ return f, None
217
+
218
+
219
+ @try_export
220
+ def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
221
+ # YOLO Paddle export
222
+ check_requirements(('paddlepaddle', 'x2paddle'))
223
+ import x2paddle
224
+ from x2paddle.convert import pytorch2paddle
225
+
226
+ LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
227
+ f = str(file).replace('.pt', f'_paddle_model{os.sep}')
228
+
229
+ pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im]) # export
230
+ yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
231
+ return f, None
232
+
233
+
234
+ @try_export
235
+ def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
236
+ # YOLO CoreML export
237
+ check_requirements('coremltools')
238
+ import coremltools as ct
239
+
240
+ LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
241
+ f = file.with_suffix('.mlmodel')
242
+
243
+ ts = torch.jit.trace(model, im, strict=False) # TorchScript model
244
+ ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
245
+ bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
246
+ if bits < 32:
247
+ if MACOS: # quantization only supported on macOS
248
+ with warnings.catch_warnings():
249
+ warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
250
+ ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
251
+ else:
252
+ print(f'{prefix} quantization only supported on macOS, skipping...')
253
+ ct_model.save(f)
254
+ return f, ct_model
255
+
256
+
257
+ @try_export
258
+ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
259
+ # YOLO TensorRT export https://developer.nvidia.com/tensorrt
260
+ assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
261
+ try:
262
+ import tensorrt as trt
263
+ except Exception:
264
+ if platform.system() == 'Linux':
265
+ check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
266
+ import tensorrt as trt
267
+
268
+ if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
269
+ grid = model.model[-1].anchor_grid
270
+ model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
271
+ export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
272
+ model.model[-1].anchor_grid = grid
273
+ else: # TensorRT >= 8
274
+ check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
275
+ export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
276
+ onnx = file.with_suffix('.onnx')
277
+
278
+ LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
279
+ assert onnx.exists(), f'failed to export ONNX file: {onnx}'
280
+ f = file.with_suffix('.engine') # TensorRT engine file
281
+ logger = trt.Logger(trt.Logger.INFO)
282
+ if verbose:
283
+ logger.min_severity = trt.Logger.Severity.VERBOSE
284
+
285
+ builder = trt.Builder(logger)
286
+ config = builder.create_builder_config()
287
+ config.max_workspace_size = workspace * 1 << 30
288
+ # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
289
+
290
+ flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
291
+ network = builder.create_network(flag)
292
+ parser = trt.OnnxParser(network, logger)
293
+ if not parser.parse_from_file(str(onnx)):
294
+ raise RuntimeError(f'failed to load ONNX file: {onnx}')
295
+
296
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
297
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
298
+ for inp in inputs:
299
+ LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
300
+ for out in outputs:
301
+ LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
302
+
303
+ if dynamic:
304
+ if im.shape[0] <= 1:
305
+ LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
306
+ profile = builder.create_optimization_profile()
307
+ for inp in inputs:
308
+ profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
309
+ config.add_optimization_profile(profile)
310
+
311
+ LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
312
+ if builder.platform_has_fast_fp16 and half:
313
+ config.set_flag(trt.BuilderFlag.FP16)
314
+ with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
315
+ t.write(engine.serialize())
316
+ return f, None
317
+
318
+
319
+ @try_export
320
+ def export_saved_model(model,
321
+ im,
322
+ file,
323
+ dynamic,
324
+ tf_nms=False,
325
+ agnostic_nms=False,
326
+ topk_per_class=100,
327
+ topk_all=100,
328
+ iou_thres=0.45,
329
+ conf_thres=0.25,
330
+ keras=False,
331
+ prefix=colorstr('TensorFlow SavedModel:')):
332
+ # YOLO TensorFlow SavedModel export
333
+ try:
334
+ import tensorflow as tf
335
+ except Exception:
336
+ check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
337
+ import tensorflow as tf
338
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
339
+
340
+ from models.tf import TFModel
341
+
342
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
343
+ f = str(file).replace('.pt', '_saved_model')
344
+ batch_size, ch, *imgsz = list(im.shape) # BCHW
345
+
346
+ tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
347
+ im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
348
+ _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
349
+ inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
350
+ outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
351
+ keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
352
+ keras_model.trainable = False
353
+ keras_model.summary()
354
+ if keras:
355
+ keras_model.save(f, save_format='tf')
356
+ else:
357
+ spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
358
+ m = tf.function(lambda x: keras_model(x)) # full model
359
+ m = m.get_concrete_function(spec)
360
+ frozen_func = convert_variables_to_constants_v2(m)
361
+ tfm = tf.Module()
362
+ tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
363
+ tfm.__call__(im)
364
+ tf.saved_model.save(tfm,
365
+ f,
366
+ options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
367
+ tf.__version__, '2.6') else tf.saved_model.SaveOptions())
368
+ return f, keras_model
369
+
370
+
371
+ @try_export
372
+ def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
373
+ # YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
374
+ import tensorflow as tf
375
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
376
+
377
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
378
+ f = file.with_suffix('.pb')
379
+
380
+ m = tf.function(lambda x: keras_model(x)) # full model
381
+ m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
382
+ frozen_func = convert_variables_to_constants_v2(m)
383
+ frozen_func.graph.as_graph_def()
384
+ tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
385
+ return f, None
386
+
387
+
388
+ @try_export
389
+ def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
390
+ # YOLOv5 TensorFlow Lite export
391
+ import tensorflow as tf
392
+
393
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
394
+ batch_size, ch, *imgsz = list(im.shape) # BCHW
395
+ f = str(file).replace('.pt', '-fp16.tflite')
396
+
397
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
398
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
399
+ converter.target_spec.supported_types = [tf.float16]
400
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
401
+ if int8:
402
+ from models.tf import representative_dataset_gen
403
+ dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
404
+ converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
405
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
406
+ converter.target_spec.supported_types = []
407
+ converter.inference_input_type = tf.uint8 # or tf.int8
408
+ converter.inference_output_type = tf.uint8 # or tf.int8
409
+ converter.experimental_new_quantizer = True
410
+ f = str(file).replace('.pt', '-int8.tflite')
411
+ if nms or agnostic_nms:
412
+ converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
413
+
414
+ tflite_model = converter.convert()
415
+ open(f, "wb").write(tflite_model)
416
+ return f, None
417
+
418
+
419
+ @try_export
420
+ def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
421
+ # YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
422
+ cmd = 'edgetpu_compiler --version'
423
+ help_url = 'https://coral.ai/docs/edgetpu/compiler/'
424
+ assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
425
+ if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
426
+ LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
427
+ sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
428
+ for c in (
429
+ 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
430
+ 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
431
+ 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
432
+ subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
433
+ ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
434
+
435
+ LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
436
+ f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model
437
+ f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
438
+
439
+ cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
440
+ subprocess.run(cmd.split(), check=True)
441
+ return f, None
442
+
443
+
444
+ @try_export
445
+ def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
446
+ # YOLO TensorFlow.js export
447
+ check_requirements('tensorflowjs')
448
+ import tensorflowjs as tfjs
449
+
450
+ LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
451
+ f = str(file).replace('.pt', '_web_model') # js dir
452
+ f_pb = file.with_suffix('.pb') # *.pb path
453
+ f_json = f'{f}/model.json' # *.json path
454
+
455
+ cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
456
+ f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
457
+ subprocess.run(cmd.split())
458
+
459
+ json = Path(f_json).read_text()
460
+ with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
461
+ subst = re.sub(
462
+ r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
463
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
464
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
465
+ r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
466
+ r'"Identity_1": {"name": "Identity_1"}, '
467
+ r'"Identity_2": {"name": "Identity_2"}, '
468
+ r'"Identity_3": {"name": "Identity_3"}}}', json)
469
+ j.write(subst)
470
+ return f, None
471
+
472
+
473
+ def add_tflite_metadata(file, metadata, num_outputs):
474
+ # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
475
+ with contextlib.suppress(ImportError):
476
+ # check_requirements('tflite_support')
477
+ from tflite_support import flatbuffers
478
+ from tflite_support import metadata as _metadata
479
+ from tflite_support import metadata_schema_py_generated as _metadata_fb
480
+
481
+ tmp_file = Path('/tmp/meta.txt')
482
+ with open(tmp_file, 'w') as meta_f:
483
+ meta_f.write(str(metadata))
484
+
485
+ model_meta = _metadata_fb.ModelMetadataT()
486
+ label_file = _metadata_fb.AssociatedFileT()
487
+ label_file.name = tmp_file.name
488
+ model_meta.associatedFiles = [label_file]
489
+
490
+ subgraph = _metadata_fb.SubGraphMetadataT()
491
+ subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
492
+ subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
493
+ model_meta.subgraphMetadata = [subgraph]
494
+
495
+ b = flatbuffers.Builder(0)
496
+ b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
497
+ metadata_buf = b.Output()
498
+
499
+ populator = _metadata.MetadataPopulator.with_model_file(file)
500
+ populator.load_metadata_buffer(metadata_buf)
501
+ populator.load_associated_files([str(tmp_file)])
502
+ populator.populate()
503
+ tmp_file.unlink()
504
+
505
+
506
+ @smart_inference_mode()
507
+ def run(
508
+ data=ROOT / 'data/coco.yaml', # 'dataset.yaml path'
509
+ weights=ROOT / 'yolo.pt', # weights path
510
+ imgsz=(640, 640), # image (height, width)
511
+ batch_size=1, # batch size
512
+ device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
513
+ include=('torchscript', 'onnx'), # include formats
514
+ half=False, # FP16 half-precision export
515
+ inplace=False, # set YOLO Detect() inplace=True
516
+ keras=False, # use Keras
517
+ optimize=False, # TorchScript: optimize for mobile
518
+ int8=False, # CoreML/TF INT8 quantization
519
+ dynamic=False, # ONNX/TF/TensorRT: dynamic axes
520
+ simplify=False, # ONNX: simplify model
521
+ opset=12, # ONNX: opset version
522
+ verbose=False, # TensorRT: verbose log
523
+ workspace=4, # TensorRT: workspace size (GB)
524
+ nms=False, # TF: add NMS to model
525
+ agnostic_nms=False, # TF: add agnostic NMS to model
526
+ topk_per_class=100, # TF.js NMS: topk per class to keep
527
+ topk_all=100, # TF.js NMS: topk for all classes to keep
528
+ iou_thres=0.45, # TF.js NMS: IoU threshold
529
+ conf_thres=0.25, # TF.js NMS: confidence threshold
530
+ ):
531
+ t = time.time()
532
+ include = [x.lower() for x in include] # to lowercase
533
+ fmts = tuple(export_formats()['Argument'][1:]) # --include arguments
534
+ flags = [x in include for x in fmts]
535
+ assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
536
+ jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
537
+ file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
538
+
539
+ # Load PyTorch model
540
+ device = select_device(device)
541
+ if half:
542
+ assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
543
+ assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
544
+ model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
545
+
546
+ # Checks
547
+ imgsz *= 2 if len(imgsz) == 1 else 1 # expand
548
+ if optimize:
549
+ assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
550
+
551
+ # Input
552
+ gs = int(max(model.stride)) # grid size (max stride)
553
+ imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
554
+ im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
555
+
556
+ # Update model
557
+ model.eval()
558
+ for k, m in model.named_modules():
559
+ if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)):
560
+ m.inplace = inplace
561
+ m.dynamic = dynamic
562
+ m.export = True
563
+
564
+ for _ in range(2):
565
+ y = model(im) # dry runs
566
+ if half and not coreml:
567
+ im, model = im.half(), model.half() # to FP16
568
+ shape = tuple((y[0] if isinstance(y, (tuple, list)) else y).shape) # model output shape
569
+ metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
570
+ LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
571
+
572
+ # Exports
573
+ f = [''] * len(fmts) # exported filenames
574
+ warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
575
+ if jit: # TorchScript
576
+ f[0], _ = export_torchscript(model, im, file, optimize)
577
+ if engine: # TensorRT required before ONNX
578
+ f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
579
+ if onnx or xml: # OpenVINO requires ONNX
580
+ f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
581
+ if onnx_end2end:
582
+ if isinstance(model, DetectionModel):
583
+ labels = model.names
584
+ f[2], _ = export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, len(labels))
585
+ else:
586
+ raise RuntimeError("The model is not a DetectionModel.")
587
+ if xml: # OpenVINO
588
+ f[3], _ = export_openvino(file, metadata, half)
589
+ if coreml: # CoreML
590
+ f[4], _ = export_coreml(model, im, file, int8, half)
591
+ if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
592
+ assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
593
+ assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
594
+ f[5], s_model = export_saved_model(model.cpu(),
595
+ im,
596
+ file,
597
+ dynamic,
598
+ tf_nms=nms or agnostic_nms or tfjs,
599
+ agnostic_nms=agnostic_nms or tfjs,
600
+ topk_per_class=topk_per_class,
601
+ topk_all=topk_all,
602
+ iou_thres=iou_thres,
603
+ conf_thres=conf_thres,
604
+ keras=keras)
605
+ if pb or tfjs: # pb prerequisite to tfjs
606
+ f[6], _ = export_pb(s_model, file)
607
+ if tflite or edgetpu:
608
+ f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
609
+ if edgetpu:
610
+ f[8], _ = export_edgetpu(file)
611
+ add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
612
+ if tfjs:
613
+ f[9], _ = export_tfjs(file)
614
+ if paddle: # PaddlePaddle
615
+ f[10], _ = export_paddle(model, im, file, metadata)
616
+
617
+ # Finish
618
+ f = [str(x) for x in f if x] # filter out '' and None
619
+ if any(f):
620
+ cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel)) # type
621
+ dir = Path('segment' if seg else 'classify' if cls else '')
622
+ h = '--half' if half else '' # --half FP16 inference arg
623
+ s = "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \
624
+ "# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else ''
625
+ if onnx_end2end:
626
+ LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
627
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
628
+ f"\nVisualize: https://netron.app")
629
+ else:
630
+ LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
631
+ f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
632
+ f"\nDetect: python {dir / ('detect.py' if det else 'predict.py')} --weights {f[-1]} {h}"
633
+ f"\nValidate: python {dir / 'val.py'} --weights {f[-1]} {h}"
634
+ f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}') {s}"
635
+ f"\nVisualize: https://netron.app")
636
+ return f # return list of exported files/dirs
637
+
638
+
639
+ def parse_opt():
640
+ parser = argparse.ArgumentParser()
641
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
642
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model.pt path(s)')
643
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
644
+ parser.add_argument('--batch-size', type=int, default=1, help='batch size')
645
+ parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
646
+ parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
647
+ parser.add_argument('--inplace', action='store_true', help='set YOLO Detect() inplace=True')
648
+ parser.add_argument('--keras', action='store_true', help='TF: use Keras')
649
+ parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
650
+ parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
651
+ parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
652
+ parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
653
+ parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
654
+ parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
655
+ parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
656
+ parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
657
+ parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
658
+ parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
659
+ parser.add_argument('--topk-all', type=int, default=100, help='ONNX END2END/TF.js NMS: topk for all classes to keep')
660
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='ONNX END2END/TF.js NMS: IoU threshold')
661
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='ONNX END2END/TF.js NMS: confidence threshold')
662
+ parser.add_argument(
663
+ '--include',
664
+ nargs='+',
665
+ default=['torchscript'],
666
+ help='torchscript, onnx, onnx_end2end, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle')
667
+ opt = parser.parse_args()
668
+
669
+ if 'onnx_end2end' in opt.include:
670
+ opt.simplify = True
671
+ opt.dynamic = True
672
+ opt.inplace = True
673
+ opt.half = False
674
+
675
+ print_args(vars(opt))
676
+ return opt
677
+
678
+
679
+ def main(opt):
680
+ for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
681
+ run(**vars(opt))
682
+
683
+
684
+ if __name__ == "__main__":
685
+ opt = parse_opt()
686
+ main(opt)
yolov9/figure/horses_prediction.jpg ADDED
yolov9/figure/multitask.png ADDED

Git LFS Details

  • SHA256: b7c83ee5db84a3760a0f854e5d70ed0e2ca1cc0f5ef5ff8a88e87d525e87eee1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.29 MB
yolov9/figure/performance.png ADDED

Git LFS Details

  • SHA256: 85f6432cf6a4a1079537d5525642722e4a09d84955e681badb9a5c7b70096baa
  • Pointer size: 131 Bytes
  • Size of remote file: 356 kB
yolov9/hubconf.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
5
+ """Creates or loads a YOLO model
6
+
7
+ Arguments:
8
+ name (str): model name 'yolov3' or path 'path/to/best.pt'
9
+ pretrained (bool): load pretrained weights into the model
10
+ channels (int): number of input channels
11
+ classes (int): number of model classes
12
+ autoshape (bool): apply YOLO .autoshape() wrapper to model
13
+ verbose (bool): print all information to screen
14
+ device (str, torch.device, None): device to use for model parameters
15
+
16
+ Returns:
17
+ YOLO model
18
+ """
19
+ from pathlib import Path
20
+
21
+ from models.common import AutoShape, DetectMultiBackend
22
+ from models.experimental import attempt_load
23
+ from models.yolo import ClassificationModel, DetectionModel, SegmentationModel
24
+ from utils.downloads import attempt_download
25
+ from utils.general import LOGGER, check_requirements, intersect_dicts, logging
26
+ from utils.torch_utils import select_device
27
+
28
+ if not verbose:
29
+ LOGGER.setLevel(logging.WARNING)
30
+ check_requirements(exclude=('opencv-python', 'tensorboard', 'thop'))
31
+ name = Path(name)
32
+ path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path
33
+ try:
34
+ device = select_device(device)
35
+ if pretrained and channels == 3 and classes == 80:
36
+ try:
37
+ model = DetectMultiBackend(path, device=device, fuse=autoshape) # detection model
38
+ if autoshape:
39
+ if model.pt and isinstance(model.model, ClassificationModel):
40
+ LOGGER.warning('WARNING ⚠️ YOLO ClassificationModel is not yet AutoShape compatible. '
41
+ 'You must pass torch tensors in BCHW to this model, i.e. shape(1,3,224,224).')
42
+ elif model.pt and isinstance(model.model, SegmentationModel):
43
+ LOGGER.warning('WARNING ⚠️ YOLO SegmentationModel is not yet AutoShape compatible. '
44
+ 'You will not be able to run inference with this model.')
45
+ else:
46
+ model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
47
+ except Exception:
48
+ model = attempt_load(path, device=device, fuse=False) # arbitrary model
49
+ else:
50
+ cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path
51
+ model = DetectionModel(cfg, channels, classes) # create model
52
+ if pretrained:
53
+ ckpt = torch.load(attempt_download(path), map_location=device) # load
54
+ csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
55
+ csd = intersect_dicts(csd, model.state_dict(), exclude=['anchors']) # intersect
56
+ model.load_state_dict(csd, strict=False) # load
57
+ if len(ckpt['model'].names) == classes:
58
+ model.names = ckpt['model'].names # set class names attribute
59
+ if not verbose:
60
+ LOGGER.setLevel(logging.INFO) # reset to default
61
+ return model.to(device)
62
+
63
+ except Exception as e:
64
+ help_url = 'https://github.com/ultralytics/yolov5/issues/36'
65
+ s = f'{e}. Cache may be out of date, try `force_reload=True` or see {help_url} for help.'
66
+ raise Exception(s) from e
67
+
68
+
69
+ def custom(path='path/to/model.pt', autoshape=True, _verbose=True, device=None):
70
+ # YOLO custom or local model
71
+ return _create(path, autoshape=autoshape, verbose=_verbose, device=device)
72
+
73
+
74
+ if __name__ == '__main__':
75
+ import argparse
76
+ from pathlib import Path
77
+
78
+ import numpy as np
79
+ from PIL import Image
80
+
81
+ from utils.general import cv2, print_args
82
+
83
+ # Argparser
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument('--model', type=str, default='yolo', help='model name')
86
+ opt = parser.parse_args()
87
+ print_args(vars(opt))
88
+
89
+ # Model
90
+ model = _create(name=opt.model, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True)
91
+ # model = custom(path='path/to/model.pt') # custom
92
+
93
+ # Images
94
+ imgs = [
95
+ 'data/images/zidane.jpg', # filename
96
+ Path('data/images/zidane.jpg'), # Path
97
+ 'https://ultralytics.com/images/zidane.jpg', # URI
98
+ cv2.imread('data/images/bus.jpg')[:, :, ::-1], # OpenCV
99
+ Image.open('data/images/bus.jpg'), # PIL
100
+ np.zeros((320, 640, 3))] # numpy
101
+
102
+ # Inference
103
+ results = model(imgs, size=320) # batched inference
104
+
105
+ # Results
106
+ results.print()
107
+ results.save()
yolov9/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # init
yolov9/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (249 Bytes). View file
 
yolov9/models/__pycache__/common.cpython-311.pyc ADDED
Binary file (110 kB). View file
 
yolov9/models/__pycache__/experimental.cpython-311.pyc ADDED
Binary file (20.2 kB). View file
 
yolov9/models/__pycache__/yolo.cpython-311.pyc ADDED
Binary file (86.3 kB). View file
 
yolov9/models/common.py ADDED
@@ -0,0 +1,1233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import contextlib
3
+ import json
4
+ import math
5
+ import platform
6
+ import warnings
7
+ import zipfile
8
+ from collections import OrderedDict, namedtuple
9
+ from copy import copy
10
+ from pathlib import Path
11
+ from urllib.parse import urlparse
12
+
13
+ from typing import Optional
14
+
15
+ import cv2
16
+ import numpy as np
17
+ import pandas as pd
18
+ import requests
19
+ import torch
20
+ import torch.nn as nn
21
+ from IPython.display import display
22
+ from PIL import Image
23
+ from torch.cuda import amp
24
+
25
+ from utils import TryExcept
26
+ from utils.dataloaders import exif_transpose, letterbox
27
+ from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
28
+ increment_path, is_notebook, make_divisible, non_max_suppression, scale_boxes,
29
+ xywh2xyxy, xyxy2xywh, yaml_load)
30
+ from utils.plots import Annotator, colors, save_one_box
31
+ from utils.torch_utils import copy_attr, smart_inference_mode
32
+
33
+
34
+ def autopad(k, p=None, d=1): # kernel, padding, dilation
35
+ # Pad to 'same' shape outputs
36
+ if d > 1:
37
+ k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
38
+ if p is None:
39
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
40
+ return p
41
+
42
+
43
+ class Conv(nn.Module):
44
+ # Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)
45
+ default_act = nn.SiLU() # default activation
46
+
47
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
48
+ super().__init__()
49
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
50
+ self.bn = nn.BatchNorm2d(c2)
51
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ return self.act(self.bn(self.conv(x)))
55
+
56
+ def forward_fuse(self, x):
57
+ return self.act(self.conv(x))
58
+
59
+
60
+ class AConv(nn.Module):
61
+ def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
62
+ super().__init__()
63
+ self.cv1 = Conv(c1, c2, 3, 2, 1)
64
+
65
+ def forward(self, x):
66
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
67
+ return self.cv1(x)
68
+
69
+
70
+ class ADown(nn.Module):
71
+ def __init__(self, c1, c2): # ch_in, ch_out, shortcut, kernels, groups, expand
72
+ super().__init__()
73
+ self.c = c2 // 2
74
+ self.cv1 = Conv(c1 // 2, self.c, 3, 2, 1)
75
+ self.cv2 = Conv(c1 // 2, self.c, 1, 1, 0)
76
+
77
+ def forward(self, x):
78
+ x = torch.nn.functional.avg_pool2d(x, 2, 1, 0, False, True)
79
+ x1,x2 = x.chunk(2, 1)
80
+ x1 = self.cv1(x1)
81
+ x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
82
+ x2 = self.cv2(x2)
83
+ return torch.cat((x1, x2), 1)
84
+
85
+
86
+ class RepConvN(nn.Module):
87
+ """RepConv is a basic rep-style block, including training and deploy status
88
+ This code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
89
+ """
90
+ default_act = nn.SiLU() # default activation
91
+
92
+ def __init__(self, c1, c2, k=3, s=1, p=1, g=1, d=1, act=True, bn=False, deploy=False):
93
+ super().__init__()
94
+ assert k == 3 and p == 1
95
+ self.g = g
96
+ self.c1 = c1
97
+ self.c2 = c2
98
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
99
+
100
+ self.bn = None
101
+ self.conv1 = Conv(c1, c2, k, s, p=p, g=g, act=False)
102
+ self.conv2 = Conv(c1, c2, 1, s, p=(p - k // 2), g=g, act=False)
103
+
104
+ def forward_fuse(self, x):
105
+ """Forward process"""
106
+ return self.act(self.conv(x))
107
+
108
+ def forward(self, x):
109
+ """Forward process"""
110
+ id_out = 0 if self.bn is None else self.bn(x)
111
+ return self.act(self.conv1(x) + self.conv2(x) + id_out)
112
+
113
+ def get_equivalent_kernel_bias(self):
114
+ kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
115
+ kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
116
+ kernelid, biasid = self._fuse_bn_tensor(self.bn)
117
+ return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
118
+
119
+ def _avg_to_3x3_tensor(self, avgp):
120
+ channels = self.c1
121
+ groups = self.g
122
+ kernel_size = avgp.kernel_size
123
+ input_dim = channels // groups
124
+ k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
125
+ k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
126
+ return k
127
+
128
+ def _pad_1x1_to_3x3_tensor(self, kernel1x1):
129
+ if kernel1x1 is None:
130
+ return 0
131
+ else:
132
+ return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])
133
+
134
+ def _fuse_bn_tensor(self, branch):
135
+ if branch is None:
136
+ return 0, 0
137
+ if isinstance(branch, Conv):
138
+ kernel = branch.conv.weight
139
+ running_mean = branch.bn.running_mean
140
+ running_var = branch.bn.running_var
141
+ gamma = branch.bn.weight
142
+ beta = branch.bn.bias
143
+ eps = branch.bn.eps
144
+ elif isinstance(branch, nn.BatchNorm2d):
145
+ if not hasattr(self, 'id_tensor'):
146
+ input_dim = self.c1 // self.g
147
+ kernel_value = np.zeros((self.c1, input_dim, 3, 3), dtype=np.float32)
148
+ for i in range(self.c1):
149
+ kernel_value[i, i % input_dim, 1, 1] = 1
150
+ self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
151
+ kernel = self.id_tensor
152
+ running_mean = branch.running_mean
153
+ running_var = branch.running_var
154
+ gamma = branch.weight
155
+ beta = branch.bias
156
+ eps = branch.eps
157
+ std = (running_var + eps).sqrt()
158
+ t = (gamma / std).reshape(-1, 1, 1, 1)
159
+ return kernel * t, beta - running_mean * gamma / std
160
+
161
+ def fuse_convs(self):
162
+ if hasattr(self, 'conv'):
163
+ return
164
+ kernel, bias = self.get_equivalent_kernel_bias()
165
+ self.conv = nn.Conv2d(in_channels=self.conv1.conv.in_channels,
166
+ out_channels=self.conv1.conv.out_channels,
167
+ kernel_size=self.conv1.conv.kernel_size,
168
+ stride=self.conv1.conv.stride,
169
+ padding=self.conv1.conv.padding,
170
+ dilation=self.conv1.conv.dilation,
171
+ groups=self.conv1.conv.groups,
172
+ bias=True).requires_grad_(False)
173
+ self.conv.weight.data = kernel
174
+ self.conv.bias.data = bias
175
+ for para in self.parameters():
176
+ para.detach_()
177
+ self.__delattr__('conv1')
178
+ self.__delattr__('conv2')
179
+ if hasattr(self, 'nm'):
180
+ self.__delattr__('nm')
181
+ if hasattr(self, 'bn'):
182
+ self.__delattr__('bn')
183
+ if hasattr(self, 'id_tensor'):
184
+ self.__delattr__('id_tensor')
185
+
186
+
187
+ class SP(nn.Module):
188
+ def __init__(self, k=3, s=1):
189
+ super(SP, self).__init__()
190
+ self.m = nn.MaxPool2d(kernel_size=k, stride=s, padding=k // 2)
191
+
192
+ def forward(self, x):
193
+ return self.m(x)
194
+
195
+
196
+ class MP(nn.Module):
197
+ # Max pooling
198
+ def __init__(self, k=2):
199
+ super(MP, self).__init__()
200
+ self.m = nn.MaxPool2d(kernel_size=k, stride=k)
201
+
202
+ def forward(self, x):
203
+ return self.m(x)
204
+
205
+
206
+ class ConvTranspose(nn.Module):
207
+ # Convolution transpose 2d layer
208
+ default_act = nn.SiLU() # default activation
209
+
210
+ def __init__(self, c1, c2, k=2, s=2, p=0, bn=True, act=True):
211
+ super().__init__()
212
+ self.conv_transpose = nn.ConvTranspose2d(c1, c2, k, s, p, bias=not bn)
213
+ self.bn = nn.BatchNorm2d(c2) if bn else nn.Identity()
214
+ self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
215
+
216
+ def forward(self, x):
217
+ return self.act(self.bn(self.conv_transpose(x)))
218
+
219
+
220
+ class DWConv(Conv):
221
+ # Depth-wise convolution
222
+ def __init__(self, c1, c2, k=1, s=1, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
223
+ super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)
224
+
225
+
226
+ class DWConvTranspose2d(nn.ConvTranspose2d):
227
+ # Depth-wise transpose convolution
228
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out
229
+ super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2))
230
+
231
+
232
+ class DFL(nn.Module):
233
+ # DFL module
234
+ def __init__(self, c1=17):
235
+ super().__init__()
236
+ self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
237
+ self.conv.weight.data[:] = nn.Parameter(torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1)) # / 120.0
238
+ self.c1 = c1
239
+ # self.bn = nn.BatchNorm2d(4)
240
+
241
+ def forward(self, x):
242
+ b, c, a = x.shape # batch, channels, anchors
243
+ return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)
244
+ # return self.conv(x.view(b, self.c1, 4, a).softmax(1)).view(b, 4, a)
245
+
246
+
247
+ class BottleneckBase(nn.Module):
248
+ # Standard bottleneck
249
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(1, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
250
+ super().__init__()
251
+ c_ = int(c2 * e) # hidden channels
252
+ self.cv1 = Conv(c1, c_, k[0], 1)
253
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
254
+ self.add = shortcut and c1 == c2
255
+
256
+ def forward(self, x):
257
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
258
+
259
+
260
+ class RBottleneckBase(nn.Module):
261
+ # Standard bottleneck
262
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
263
+ super().__init__()
264
+ c_ = int(c2 * e) # hidden channels
265
+ self.cv1 = Conv(c1, c_, k[0], 1)
266
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
267
+ self.add = shortcut and c1 == c2
268
+
269
+ def forward(self, x):
270
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
271
+
272
+
273
+ class RepNRBottleneckBase(nn.Module):
274
+ # Standard bottleneck
275
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 1), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
276
+ super().__init__()
277
+ c_ = int(c2 * e) # hidden channels
278
+ self.cv1 = RepConvN(c1, c_, k[0], 1)
279
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
280
+ self.add = shortcut and c1 == c2
281
+
282
+ def forward(self, x):
283
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
284
+
285
+
286
+ class Bottleneck(nn.Module):
287
+ # Standard bottleneck
288
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
289
+ super().__init__()
290
+ c_ = int(c2 * e) # hidden channels
291
+ self.cv1 = Conv(c1, c_, k[0], 1)
292
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
293
+ self.add = shortcut and c1 == c2
294
+
295
+ def forward(self, x):
296
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
297
+
298
+
299
+ class RepNBottleneck(nn.Module):
300
+ # Standard bottleneck
301
+ def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
302
+ super().__init__()
303
+ c_ = int(c2 * e) # hidden channels
304
+ self.cv1 = RepConvN(c1, c_, k[0], 1)
305
+ self.cv2 = Conv(c_, c2, k[1], 1, g=g)
306
+ self.add = shortcut and c1 == c2
307
+
308
+ def forward(self, x):
309
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
310
+
311
+
312
+ class Res(nn.Module):
313
+ # ResNet bottleneck
314
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
315
+ super(Res, self).__init__()
316
+ c_ = int(c2 * e) # hidden channels
317
+ self.cv1 = Conv(c1, c_, 1, 1)
318
+ self.cv2 = Conv(c_, c_, 3, 1, g=g)
319
+ self.cv3 = Conv(c_, c2, 1, 1)
320
+ self.add = shortcut and c1 == c2
321
+
322
+ def forward(self, x):
323
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
324
+
325
+
326
+ class RepNRes(nn.Module):
327
+ # ResNet bottleneck
328
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
329
+ super(RepNRes, self).__init__()
330
+ c_ = int(c2 * e) # hidden channels
331
+ self.cv1 = Conv(c1, c_, 1, 1)
332
+ self.cv2 = RepConvN(c_, c_, 3, 1, g=g)
333
+ self.cv3 = Conv(c_, c2, 1, 1)
334
+ self.add = shortcut and c1 == c2
335
+
336
+ def forward(self, x):
337
+ return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
338
+
339
+
340
+ class BottleneckCSP(nn.Module):
341
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
342
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
343
+ super().__init__()
344
+ c_ = int(c2 * e) # hidden channels
345
+ self.cv1 = Conv(c1, c_, 1, 1)
346
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
347
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
348
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
349
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
350
+ self.act = nn.SiLU()
351
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
352
+
353
+ def forward(self, x):
354
+ y1 = self.cv3(self.m(self.cv1(x)))
355
+ y2 = self.cv2(x)
356
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))
357
+
358
+
359
+ class CSP(nn.Module):
360
+ # CSP Bottleneck with 3 convolutions
361
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
362
+ super().__init__()
363
+ c_ = int(c2 * e) # hidden channels
364
+ self.cv1 = Conv(c1, c_, 1, 1)
365
+ self.cv2 = Conv(c1, c_, 1, 1)
366
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
367
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
368
+
369
+ def forward(self, x):
370
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
371
+
372
+
373
+ class RepNCSP(nn.Module):
374
+ # CSP Bottleneck with 3 convolutions
375
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
376
+ super().__init__()
377
+ c_ = int(c2 * e) # hidden channels
378
+ self.cv1 = Conv(c1, c_, 1, 1)
379
+ self.cv2 = Conv(c1, c_, 1, 1)
380
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
381
+ self.m = nn.Sequential(*(RepNBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
382
+
383
+ def forward(self, x):
384
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
385
+
386
+
387
+ class CSPBase(nn.Module):
388
+ # CSP Bottleneck with 3 convolutions
389
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
390
+ super().__init__()
391
+ c_ = int(c2 * e) # hidden channels
392
+ self.cv1 = Conv(c1, c_, 1, 1)
393
+ self.cv2 = Conv(c1, c_, 1, 1)
394
+ self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
395
+ self.m = nn.Sequential(*(BottleneckBase(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
396
+
397
+ def forward(self, x):
398
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
399
+
400
+
401
+ class SPP(nn.Module):
402
+ # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
403
+ def __init__(self, c1, c2, k=(5, 9, 13)):
404
+ super().__init__()
405
+ c_ = c1 // 2 # hidden channels
406
+ self.cv1 = Conv(c1, c_, 1, 1)
407
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
408
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
409
+
410
+ def forward(self, x):
411
+ x = self.cv1(x)
412
+ with warnings.catch_warnings():
413
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
414
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
415
+
416
+
417
+ class ASPP(torch.nn.Module):
418
+
419
+ def __init__(self, in_channels, out_channels):
420
+ super().__init__()
421
+ kernel_sizes = [1, 3, 3, 1]
422
+ dilations = [1, 3, 6, 1]
423
+ paddings = [0, 3, 6, 0]
424
+ self.aspp = torch.nn.ModuleList()
425
+ for aspp_idx in range(len(kernel_sizes)):
426
+ conv = torch.nn.Conv2d(
427
+ in_channels,
428
+ out_channels,
429
+ kernel_size=kernel_sizes[aspp_idx],
430
+ stride=1,
431
+ dilation=dilations[aspp_idx],
432
+ padding=paddings[aspp_idx],
433
+ bias=True)
434
+ self.aspp.append(conv)
435
+ self.gap = torch.nn.AdaptiveAvgPool2d(1)
436
+ self.aspp_num = len(kernel_sizes)
437
+ for m in self.modules():
438
+ if isinstance(m, torch.nn.Conv2d):
439
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
440
+ m.weight.data.normal_(0, math.sqrt(2. / n))
441
+ m.bias.data.fill_(0)
442
+
443
+ def forward(self, x):
444
+ avg_x = self.gap(x)
445
+ out = []
446
+ for aspp_idx in range(self.aspp_num):
447
+ inp = avg_x if (aspp_idx == self.aspp_num - 1) else x
448
+ out.append(F.relu_(self.aspp[aspp_idx](inp)))
449
+ out[-1] = out[-1].expand_as(out[-2])
450
+ out = torch.cat(out, dim=1)
451
+ return out
452
+
453
+
454
+ class SPPCSPC(nn.Module):
455
+ # CSP SPP https://github.com/WongKinYiu/CrossStagePartialNetworks
456
+ def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, k=(5, 9, 13)):
457
+ super(SPPCSPC, self).__init__()
458
+ c_ = int(2 * c2 * e) # hidden channels
459
+ self.cv1 = Conv(c1, c_, 1, 1)
460
+ self.cv2 = Conv(c1, c_, 1, 1)
461
+ self.cv3 = Conv(c_, c_, 3, 1)
462
+ self.cv4 = Conv(c_, c_, 1, 1)
463
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
464
+ self.cv5 = Conv(4 * c_, c_, 1, 1)
465
+ self.cv6 = Conv(c_, c_, 3, 1)
466
+ self.cv7 = Conv(2 * c_, c2, 1, 1)
467
+
468
+ def forward(self, x):
469
+ x1 = self.cv4(self.cv3(self.cv1(x)))
470
+ y1 = self.cv6(self.cv5(torch.cat([x1] + [m(x1) for m in self.m], 1)))
471
+ y2 = self.cv2(x)
472
+ return self.cv7(torch.cat((y1, y2), dim=1))
473
+
474
+
475
+ class SPPF(nn.Module):
476
+ # Spatial Pyramid Pooling - Fast (SPPF) layer by Glenn Jocher
477
+ def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
478
+ super().__init__()
479
+ c_ = c1 // 2 # hidden channels
480
+ self.cv1 = Conv(c1, c_, 1, 1)
481
+ self.cv2 = Conv(c_ * 4, c2, 1, 1)
482
+ self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
483
+ # self.m = SoftPool2d(kernel_size=k, stride=1, padding=k // 2)
484
+
485
+ def forward(self, x):
486
+ x = self.cv1(x)
487
+ with warnings.catch_warnings():
488
+ warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
489
+ y1 = self.m(x)
490
+ y2 = self.m(y1)
491
+ return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
492
+
493
+
494
+ import torch.nn.functional as F
495
+ from torch.nn.modules.utils import _pair
496
+
497
+
498
+ class ReOrg(nn.Module):
499
+ # yolo
500
+ def __init__(self):
501
+ super(ReOrg, self).__init__()
502
+
503
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
504
+ return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
505
+
506
+
507
+ class Contract(nn.Module):
508
+ # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
509
+ def __init__(self, gain=2):
510
+ super().__init__()
511
+ self.gain = gain
512
+
513
+ def forward(self, x):
514
+ b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
515
+ s = self.gain
516
+ x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
517
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
518
+ return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
519
+
520
+
521
+ class Expand(nn.Module):
522
+ # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
523
+ def __init__(self, gain=2):
524
+ super().__init__()
525
+ self.gain = gain
526
+
527
+ def forward(self, x):
528
+ b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
529
+ s = self.gain
530
+ x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
531
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
532
+ return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
533
+
534
+
535
+ class Concat(nn.Module):
536
+ # Concatenate a list of tensors along dimension
537
+ def __init__(self, dimension=1):
538
+ super().__init__()
539
+ self.d = dimension
540
+
541
+ def forward(self, x):
542
+ return torch.cat(x, self.d)
543
+
544
+
545
+ class Shortcut(nn.Module):
546
+ def __init__(self, dimension=0):
547
+ super(Shortcut, self).__init__()
548
+ self.d = dimension
549
+
550
+ def forward(self, x):
551
+ return x[0]+x[1]
552
+
553
+
554
+ class Silence(nn.Module):
555
+ def __init__(self):
556
+ super(Silence, self).__init__()
557
+ def forward(self, x):
558
+ return x
559
+
560
+
561
+ ##### GELAN #####
562
+
563
+ class SPPELAN(nn.Module):
564
+ # spp-elan
565
+ def __init__(self, c1, c2, c3): # ch_in, ch_out, number, shortcut, groups, expansion
566
+ super().__init__()
567
+ self.c = c3
568
+ self.cv1 = Conv(c1, c3, 1, 1)
569
+ self.cv2 = SP(5)
570
+ self.cv3 = SP(5)
571
+ self.cv4 = SP(5)
572
+ self.cv5 = Conv(4*c3, c2, 1, 1)
573
+
574
+ def forward(self, x):
575
+ y = [self.cv1(x)]
576
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
577
+ return self.cv5(torch.cat(y, 1))
578
+
579
+
580
+ class ELAN1(nn.Module):
581
+
582
+ def __init__(self, c1, c2, c3, c4): # ch_in, ch_out, number, shortcut, groups, expansion
583
+ super().__init__()
584
+ self.c = c3//2
585
+ self.cv1 = Conv(c1, c3, 1, 1)
586
+ self.cv2 = Conv(c3//2, c4, 3, 1)
587
+ self.cv3 = Conv(c4, c4, 3, 1)
588
+ self.cv4 = Conv(c3+(2*c4), c2, 1, 1)
589
+
590
+ def forward(self, x):
591
+ y = list(self.cv1(x).chunk(2, 1))
592
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
593
+ return self.cv4(torch.cat(y, 1))
594
+
595
+ def forward_split(self, x):
596
+ y = list(self.cv1(x).split((self.c, self.c), 1))
597
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
598
+ return self.cv4(torch.cat(y, 1))
599
+
600
+
601
+ class RepNCSPELAN4(nn.Module):
602
+ # csp-elan
603
+ def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
604
+ super().__init__()
605
+ self.c = c3//2
606
+ self.cv1 = Conv(c1, c3, 1, 1)
607
+ self.cv2 = nn.Sequential(RepNCSP(c3//2, c4, c5), Conv(c4, c4, 3, 1))
608
+ self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
609
+ self.cv4 = Conv(c3+(2*c4), c2, 1, 1)
610
+
611
+ def forward(self, x):
612
+ y = list(self.cv1(x).chunk(2, 1))
613
+ y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
614
+ return self.cv4(torch.cat(y, 1))
615
+
616
+ def forward_split(self, x):
617
+ y = list(self.cv1(x).split((self.c, self.c), 1))
618
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
619
+ return self.cv4(torch.cat(y, 1))
620
+
621
+ #################
622
+
623
+
624
+ ##### YOLOR #####
625
+
626
+ class ImplicitA(nn.Module):
627
+ def __init__(self, channel):
628
+ super(ImplicitA, self).__init__()
629
+ self.channel = channel
630
+ self.implicit = nn.Parameter(torch.zeros(1, channel, 1, 1))
631
+ nn.init.normal_(self.implicit, std=.02)
632
+
633
+ def forward(self, x):
634
+ return self.implicit + x
635
+
636
+
637
+ class ImplicitM(nn.Module):
638
+ def __init__(self, channel):
639
+ super(ImplicitM, self).__init__()
640
+ self.channel = channel
641
+ self.implicit = nn.Parameter(torch.ones(1, channel, 1, 1))
642
+ nn.init.normal_(self.implicit, mean=1., std=.02)
643
+
644
+ def forward(self, x):
645
+ return self.implicit * x
646
+
647
+ #################
648
+
649
+
650
+ ##### CBNet #####
651
+
652
+ class CBLinear(nn.Module):
653
+ def __init__(self, c1, c2s, k=1, s=1, p=None, g=1): # ch_in, ch_outs, kernel, stride, padding, groups
654
+ super(CBLinear, self).__init__()
655
+ self.c2s = c2s
656
+ self.conv = nn.Conv2d(c1, sum(c2s), k, s, autopad(k, p), groups=g, bias=True)
657
+
658
+ def forward(self, x):
659
+ outs = self.conv(x).split(self.c2s, dim=1)
660
+ return outs
661
+
662
+ class CBFuse(nn.Module):
663
+ def __init__(self, idx):
664
+ super(CBFuse, self).__init__()
665
+ self.idx = idx
666
+
667
+ def forward(self, xs):
668
+ target_size = xs[-1].shape[2:]
669
+ res = [F.interpolate(x[self.idx[i]], size=target_size, mode='nearest') for i, x in enumerate(xs[:-1])]
670
+ out = torch.sum(torch.stack(res + xs[-1:]), dim=0)
671
+ return out
672
+
673
+ #################
674
+
675
+
676
+ class DetectMultiBackend(nn.Module):
677
+ # YOLO MultiBackend class for python inference on various backends
678
+ def __init__(self, weights='yolo.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
679
+ # Usage:
680
+ # PyTorch: weights = *.pt
681
+ # TorchScript: *.torchscript
682
+ # ONNX Runtime: *.onnx
683
+ # ONNX OpenCV DNN: *.onnx --dnn
684
+ # OpenVINO: *_openvino_model
685
+ # CoreML: *.mlmodel
686
+ # TensorRT: *.engine
687
+ # TensorFlow SavedModel: *_saved_model
688
+ # TensorFlow GraphDef: *.pb
689
+ # TensorFlow Lite: *.tflite
690
+ # TensorFlow Edge TPU: *_edgetpu.tflite
691
+ # PaddlePaddle: *_paddle_model
692
+ from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
693
+
694
+ super().__init__()
695
+ w = str(weights[0] if isinstance(weights, list) else weights)
696
+ pt, jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
697
+ fp16 &= pt or jit or onnx or engine # FP16
698
+ nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
699
+ stride = 32 # default stride
700
+ cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
701
+ if not (pt or triton):
702
+ w = attempt_download(w) # download if not local
703
+
704
+ if pt: # PyTorch
705
+ model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
706
+ stride = max(int(model.stride.max()), 32) # model stride
707
+ names = model.module.names if hasattr(model, 'module') else model.names # get class names
708
+ model.half() if fp16 else model.float()
709
+ self.model = model # explicitly assign for to(), cpu(), cuda(), half()
710
+ elif jit: # TorchScript
711
+ LOGGER.info(f'Loading {w} for TorchScript inference...')
712
+ extra_files = {'config.txt': ''} # model metadata
713
+ model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
714
+ model.half() if fp16 else model.float()
715
+ if extra_files['config.txt']: # load metadata dict
716
+ d = json.loads(extra_files['config.txt'],
717
+ object_hook=lambda d: {int(k) if k.isdigit() else k: v
718
+ for k, v in d.items()})
719
+ stride, names = int(d['stride']), d['names']
720
+ elif dnn: # ONNX OpenCV DNN
721
+ LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
722
+ check_requirements('opencv-python>=4.5.4')
723
+ net = cv2.dnn.readNetFromONNX(w)
724
+ elif onnx: # ONNX Runtime
725
+ LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
726
+ check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
727
+ import onnxruntime
728
+ providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
729
+ session = onnxruntime.InferenceSession(w, providers=providers)
730
+ output_names = [x.name for x in session.get_outputs()]
731
+ meta = session.get_modelmeta().custom_metadata_map # metadata
732
+ if 'stride' in meta:
733
+ stride, names = int(meta['stride']), eval(meta['names'])
734
+ elif xml: # OpenVINO
735
+ LOGGER.info(f'Loading {w} for OpenVINO inference...')
736
+ check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
737
+ from openvino.runtime import Core, Layout, get_batch
738
+ ie = Core()
739
+ if not Path(w).is_file(): # if not *.xml
740
+ w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
741
+ network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
742
+ if network.get_parameters()[0].get_layout().empty:
743
+ network.get_parameters()[0].set_layout(Layout("NCHW"))
744
+ batch_dim = get_batch(network)
745
+ if batch_dim.is_static:
746
+ batch_size = batch_dim.get_length()
747
+ executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
748
+ stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
749
+ elif engine: # TensorRT
750
+ LOGGER.info(f'Loading {w} for TensorRT inference...')
751
+ import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
752
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
753
+ if device.type == 'cpu':
754
+ device = torch.device('cuda:0')
755
+ Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
756
+ logger = trt.Logger(trt.Logger.INFO)
757
+ with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
758
+ model = runtime.deserialize_cuda_engine(f.read())
759
+ context = model.create_execution_context()
760
+ bindings = OrderedDict()
761
+ output_names = []
762
+ fp16 = False # default updated below
763
+ dynamic = False
764
+ for i in range(model.num_bindings):
765
+ name = model.get_binding_name(i)
766
+ dtype = trt.nptype(model.get_binding_dtype(i))
767
+ if model.binding_is_input(i):
768
+ if -1 in tuple(model.get_binding_shape(i)): # dynamic
769
+ dynamic = True
770
+ context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
771
+ if dtype == np.float16:
772
+ fp16 = True
773
+ else: # output
774
+ output_names.append(name)
775
+ shape = tuple(context.get_binding_shape(i))
776
+ im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
777
+ bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
778
+ binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
779
+ batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
780
+ elif coreml: # CoreML
781
+ LOGGER.info(f'Loading {w} for CoreML inference...')
782
+ import coremltools as ct
783
+ model = ct.models.MLModel(w)
784
+ elif saved_model: # TF SavedModel
785
+ LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
786
+ import tensorflow as tf
787
+ keras = False # assume TF1 saved_model
788
+ model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
789
+ elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
790
+ LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
791
+ import tensorflow as tf
792
+
793
+ def wrap_frozen_graph(gd, inputs, outputs):
794
+ x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
795
+ ge = x.graph.as_graph_element
796
+ return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
797
+
798
+ def gd_outputs(gd):
799
+ name_list, input_list = [], []
800
+ for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
801
+ name_list.append(node.name)
802
+ input_list.extend(node.input)
803
+ return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
804
+
805
+ gd = tf.Graph().as_graph_def() # TF GraphDef
806
+ with open(w, 'rb') as f:
807
+ gd.ParseFromString(f.read())
808
+ frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
809
+ elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
810
+ try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
811
+ from tflite_runtime.interpreter import Interpreter, load_delegate
812
+ except ImportError:
813
+ import tensorflow as tf
814
+ Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
815
+ if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
816
+ LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
817
+ delegate = {
818
+ 'Linux': 'libedgetpu.so.1',
819
+ 'Darwin': 'libedgetpu.1.dylib',
820
+ 'Windows': 'edgetpu.dll'}[platform.system()]
821
+ interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
822
+ else: # TFLite
823
+ LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
824
+ interpreter = Interpreter(model_path=w) # load TFLite model
825
+ interpreter.allocate_tensors() # allocate
826
+ input_details = interpreter.get_input_details() # inputs
827
+ output_details = interpreter.get_output_details() # outputs
828
+ # load metadata
829
+ with contextlib.suppress(zipfile.BadZipFile):
830
+ with zipfile.ZipFile(w, "r") as model:
831
+ meta_file = model.namelist()[0]
832
+ meta = ast.literal_eval(model.read(meta_file).decode("utf-8"))
833
+ stride, names = int(meta['stride']), meta['names']
834
+ elif tfjs: # TF.js
835
+ raise NotImplementedError('ERROR: YOLO TF.js inference is not supported')
836
+ elif paddle: # PaddlePaddle
837
+ LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
838
+ check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
839
+ import paddle.inference as pdi
840
+ if not Path(w).is_file(): # if not *.pdmodel
841
+ w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
842
+ weights = Path(w).with_suffix('.pdiparams')
843
+ config = pdi.Config(str(w), str(weights))
844
+ if cuda:
845
+ config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
846
+ predictor = pdi.create_predictor(config)
847
+ input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
848
+ output_names = predictor.get_output_names()
849
+ elif triton: # NVIDIA Triton Inference Server
850
+ LOGGER.info(f'Using {w} as Triton Inference Server...')
851
+ check_requirements('tritonclient[all]')
852
+ from utils.triton import TritonRemoteModel
853
+ model = TritonRemoteModel(url=w)
854
+ nhwc = model.runtime.startswith("tensorflow")
855
+ else:
856
+ raise NotImplementedError(f'ERROR: {w} is not a supported format')
857
+
858
+ # class names
859
+ if 'names' not in locals():
860
+ names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
861
+ if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
862
+ names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
863
+
864
+ self.__dict__.update(locals()) # assign all variables to self
865
+
866
+ def forward(self, im, augment=False, visualize=False):
867
+ # YOLO MultiBackend inference
868
+ b, ch, h, w = im.shape # batch, channel, height, width
869
+ if self.fp16 and im.dtype != torch.float16:
870
+ im = im.half() # to FP16
871
+ if self.nhwc:
872
+ im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
873
+
874
+ if self.pt: # PyTorch
875
+ y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
876
+ elif self.jit: # TorchScript
877
+ y = self.model(im)
878
+ elif self.dnn: # ONNX OpenCV DNN
879
+ im = im.cpu().numpy() # torch to numpy
880
+ self.net.setInput(im)
881
+ y = self.net.forward()
882
+ elif self.onnx: # ONNX Runtime
883
+ im = im.cpu().numpy() # torch to numpy
884
+ y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
885
+ elif self.xml: # OpenVINO
886
+ im = im.cpu().numpy() # FP32
887
+ y = list(self.executable_network([im]).values())
888
+ elif self.engine: # TensorRT
889
+ if self.dynamic and im.shape != self.bindings['images'].shape:
890
+ i = self.model.get_binding_index('images')
891
+ self.context.set_binding_shape(i, im.shape) # reshape if dynamic
892
+ self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
893
+ for name in self.output_names:
894
+ i = self.model.get_binding_index(name)
895
+ self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
896
+ s = self.bindings['images'].shape
897
+ assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
898
+ self.binding_addrs['images'] = int(im.data_ptr())
899
+ self.context.execute_v2(list(self.binding_addrs.values()))
900
+ y = [self.bindings[x].data for x in sorted(self.output_names)]
901
+ elif self.coreml: # CoreML
902
+ im = im.cpu().numpy()
903
+ im = Image.fromarray((im[0] * 255).astype('uint8'))
904
+ # im = im.resize((192, 320), Image.ANTIALIAS)
905
+ y = self.model.predict({'image': im}) # coordinates are xywh normalized
906
+ if 'confidence' in y:
907
+ box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
908
+ conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
909
+ y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
910
+ else:
911
+ y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
912
+ elif self.paddle: # PaddlePaddle
913
+ im = im.cpu().numpy().astype(np.float32)
914
+ self.input_handle.copy_from_cpu(im)
915
+ self.predictor.run()
916
+ y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
917
+ elif self.triton: # NVIDIA Triton Inference Server
918
+ y = self.model(im)
919
+ else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
920
+ im = im.cpu().numpy()
921
+ if self.saved_model: # SavedModel
922
+ y = self.model(im, training=False) if self.keras else self.model(im)
923
+ elif self.pb: # GraphDef
924
+ y = self.frozen_func(x=self.tf.constant(im))
925
+ else: # Lite or Edge TPU
926
+ input = self.input_details[0]
927
+ int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
928
+ if int8:
929
+ scale, zero_point = input['quantization']
930
+ im = (im / scale + zero_point).astype(np.uint8) # de-scale
931
+ self.interpreter.set_tensor(input['index'], im)
932
+ self.interpreter.invoke()
933
+ y = []
934
+ for output in self.output_details:
935
+ x = self.interpreter.get_tensor(output['index'])
936
+ if int8:
937
+ scale, zero_point = output['quantization']
938
+ x = (x.astype(np.float32) - zero_point) * scale # re-scale
939
+ y.append(x)
940
+ y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
941
+ y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
942
+
943
+ if isinstance(y, (list, tuple)):
944
+ return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
945
+ else:
946
+ return self.from_numpy(y)
947
+
948
+ def from_numpy(self, x):
949
+ return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
950
+
951
+ def warmup(self, imgsz=(1, 3, 640, 640)):
952
+ # Warmup model by running inference once
953
+ warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton
954
+ if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
955
+ im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
956
+ for _ in range(2 if self.jit else 1): #
957
+ self.forward(im) # warmup
958
+
959
+ @staticmethod
960
+ def _model_type(p='path/to/model.pt'):
961
+ # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
962
+ # types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
963
+ from export import export_formats
964
+ from utils.downloads import is_url
965
+ sf = list(export_formats().Suffix) # export suffixes
966
+ if not is_url(p, check=False):
967
+ check_suffix(p, sf) # checks
968
+ url = urlparse(p) # if url may be Triton inference server
969
+ types = [s in Path(p).name for s in sf]
970
+ types[8] &= not types[9] # tflite &= not edgetpu
971
+ triton = not any(types) and all([any(s in url.scheme for s in ["http", "grpc"]), url.netloc])
972
+ return types + [triton]
973
+
974
+ @staticmethod
975
+ def _load_metadata(f=Path('path/to/meta.yaml')):
976
+ # Load metadata from meta.yaml if it exists
977
+ if f.exists():
978
+ d = yaml_load(f)
979
+ return d['stride'], d['names'] # assign stride, names
980
+ return None, None
981
+
982
+
983
+ class AutoShape(nn.Module):
984
+ # YOLO input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
985
+ conf = 0.25 # NMS confidence threshold
986
+ iou = 0.45 # NMS IoU threshold
987
+ agnostic = False # NMS class-agnostic
988
+ multi_label = False # NMS multiple labels per box
989
+ classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
990
+ max_det = 1000 # maximum number of detections per image
991
+ amp = False # Automatic Mixed Precision (AMP) inference
992
+
993
+ def __init__(self, model, verbose=True):
994
+ super().__init__()
995
+ if verbose:
996
+ LOGGER.info('Adding AutoShape... ')
997
+ copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
998
+ self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
999
+ self.pt = not self.dmb or model.pt # PyTorch model
1000
+ self.model = model.eval()
1001
+ if self.pt:
1002
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
1003
+ m.inplace = False # Detect.inplace=False for safe multithread inference
1004
+ m.export = True # do not output loss values
1005
+
1006
+ def _apply(self, fn):
1007
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
1008
+ self = super()._apply(fn)
1009
+ from models.yolo import Detect, Segment
1010
+ if self.pt:
1011
+ m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
1012
+ if isinstance(m, (Detect, Segment)):
1013
+ for k in 'stride', 'anchor_grid', 'stride_grid', 'grid':
1014
+ x = getattr(m, k)
1015
+ setattr(m, k, list(map(fn, x))) if isinstance(x, (list, tuple)) else setattr(m, k, fn(x))
1016
+ return self
1017
+
1018
+ @smart_inference_mode()
1019
+ def forward(self, ims, size=640, augment=False, profile=False):
1020
+ # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
1021
+ # file: ims = 'data/images/zidane.jpg' # str or PosixPath
1022
+ # URI: = 'https://ultralytics.com/images/zidane.jpg'
1023
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
1024
+ # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
1025
+ # numpy: = np.zeros((640,1280,3)) # HWC
1026
+ # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
1027
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
1028
+
1029
+ dt = (Profile(), Profile(), Profile())
1030
+ with dt[0]:
1031
+ if isinstance(size, int): # expand
1032
+ size = (size, size)
1033
+ p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
1034
+ autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
1035
+ if isinstance(ims, torch.Tensor): # torch
1036
+ with amp.autocast(autocast):
1037
+ return self.model(ims.to(p.device).type_as(p), augment=augment) # inference
1038
+
1039
+ # Pre-process
1040
+ n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
1041
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
1042
+ for i, im in enumerate(ims):
1043
+ f = f'image{i}' # filename
1044
+ if isinstance(im, (str, Path)): # filename or uri
1045
+ im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
1046
+ im = np.asarray(exif_transpose(im))
1047
+ elif isinstance(im, Image.Image): # PIL Image
1048
+ im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
1049
+ files.append(Path(f).with_suffix('.jpg').name)
1050
+ if im.shape[0] < 5: # image in CHW
1051
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
1052
+ im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
1053
+ s = im.shape[:2] # HWC
1054
+ shape0.append(s) # image shape
1055
+ g = max(size) / max(s) # gain
1056
+ shape1.append([int(y * g) for y in s])
1057
+ ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
1058
+ shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] # inf shape
1059
+ x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
1060
+ x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
1061
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
1062
+
1063
+ with amp.autocast(autocast):
1064
+ # Inference
1065
+ with dt[1]:
1066
+ y = self.model(x, augment=augment) # forward
1067
+
1068
+ # Post-process
1069
+ with dt[2]:
1070
+ y = non_max_suppression(y if self.dmb else y[0],
1071
+ self.conf,
1072
+ self.iou,
1073
+ self.classes,
1074
+ self.agnostic,
1075
+ self.multi_label,
1076
+ max_det=self.max_det) # NMS
1077
+ for i in range(n):
1078
+ scale_boxes(shape1, y[i][:, :4], shape0[i])
1079
+
1080
+ return Detections(ims, y, files, dt, self.names, x.shape)
1081
+
1082
+
1083
+ class Detections:
1084
+ # YOLO detections class for inference results
1085
+ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
1086
+ super().__init__()
1087
+ d = pred[0].device # device
1088
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
1089
+ self.ims = ims # list of images as numpy arrays
1090
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
1091
+ self.names = names # class names
1092
+ self.files = files # image filenames
1093
+ self.times = times # profiling times
1094
+ self.xyxy = pred # xyxy pixels
1095
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
1096
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
1097
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
1098
+ self.n = len(self.pred) # number of images (batch size)
1099
+ self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
1100
+ self.s = tuple(shape) # inference BCHW shape
1101
+
1102
+ def _run(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
1103
+ s, crops = '', []
1104
+ for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
1105
+ s += f'\nimage {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
1106
+ if pred.shape[0]:
1107
+ for c in pred[:, -1].unique():
1108
+ n = (pred[:, -1] == c).sum() # detections per class
1109
+ s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
1110
+ s = s.rstrip(', ')
1111
+ if show or save or render or crop:
1112
+ annotator = Annotator(im, example=str(self.names))
1113
+ for *box, conf, cls in reversed(pred): # xyxy, confidence, class
1114
+ label = f'{self.names[int(cls)]} {conf:.2f}'
1115
+ if crop:
1116
+ file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
1117
+ crops.append({
1118
+ 'box': box,
1119
+ 'conf': conf,
1120
+ 'cls': cls,
1121
+ 'label': label,
1122
+ 'im': save_one_box(box, im, file=file, save=save)})
1123
+ else: # all others
1124
+ annotator.box_label(box, label if labels else '', color=colors(cls))
1125
+ im = annotator.im
1126
+ else:
1127
+ s += '(no detections)'
1128
+
1129
+ im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
1130
+ if show:
1131
+ display(im) if is_notebook() else im.show(self.files[i])
1132
+ if save:
1133
+ f = self.files[i]
1134
+ im.save(save_dir / f) # save
1135
+ if i == self.n - 1:
1136
+ LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
1137
+ if render:
1138
+ self.ims[i] = np.asarray(im)
1139
+ if pprint:
1140
+ s = s.lstrip('\n')
1141
+ return f'{s}\nSpeed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {self.s}' % self.t
1142
+ if crop:
1143
+ if save:
1144
+ LOGGER.info(f'Saved results to {save_dir}\n')
1145
+ return crops
1146
+
1147
+ @TryExcept('Showing images is not supported in this environment')
1148
+ def show(self, labels=True):
1149
+ self._run(show=True, labels=labels) # show results
1150
+
1151
+ def save(self, labels=True, save_dir='runs/detect/exp', exist_ok=False):
1152
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) # increment save_dir
1153
+ self._run(save=True, labels=labels, save_dir=save_dir) # save results
1154
+
1155
+ def crop(self, save=True, save_dir='runs/detect/exp', exist_ok=False):
1156
+ save_dir = increment_path(save_dir, exist_ok, mkdir=True) if save else None
1157
+ return self._run(crop=True, save=save, save_dir=save_dir) # crop results
1158
+
1159
+ def render(self, labels=True):
1160
+ self._run(render=True, labels=labels) # render results
1161
+ return self.ims
1162
+
1163
+ def pandas(self):
1164
+ # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
1165
+ new = copy(self) # return copy
1166
+ ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
1167
+ cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
1168
+ for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
1169
+ a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
1170
+ setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
1171
+ return new
1172
+
1173
+ def tolist(self):
1174
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
1175
+ r = range(self.n) # iterable
1176
+ x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
1177
+ # for d in x:
1178
+ # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
1179
+ # setattr(d, k, getattr(d, k)[0]) # pop out of list
1180
+ return x
1181
+
1182
+ def print(self):
1183
+ LOGGER.info(self.__str__())
1184
+
1185
+ def __len__(self): # override len(results)
1186
+ return self.n
1187
+
1188
+ def __str__(self): # override print(results)
1189
+ return self._run(pprint=True) # print results
1190
+
1191
+ def __repr__(self):
1192
+ return f'YOLO {self.__class__} instance\n' + self.__str__()
1193
+
1194
+
1195
+ class Proto(nn.Module):
1196
+ # YOLO mask Proto module for segmentation models
1197
+ def __init__(self, c1, c_=256, c2=32): # ch_in, number of protos, number of masks
1198
+ super().__init__()
1199
+ self.cv1 = Conv(c1, c_, k=3)
1200
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
1201
+ self.cv2 = Conv(c_, c_, k=3)
1202
+ self.cv3 = Conv(c_, c2)
1203
+
1204
+ def forward(self, x):
1205
+ return self.cv3(self.cv2(self.upsample(self.cv1(x))))
1206
+
1207
+
1208
+ class UConv(nn.Module):
1209
+ def __init__(self, c1, c_=256, c2=256): # ch_in, number of protos, number of masks
1210
+ super().__init__()
1211
+
1212
+ self.cv1 = Conv(c1, c_, k=3)
1213
+ self.cv2 = nn.Conv2d(c_, c2, 1, 1)
1214
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
1215
+
1216
+ def forward(self, x):
1217
+ return self.up(self.cv2(self.cv1(x)))
1218
+
1219
+
1220
+ class Classify(nn.Module):
1221
+ # YOLO classification head, i.e. x(b,c1,20,20) to x(b,c2)
1222
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
1223
+ super().__init__()
1224
+ c_ = 1280 # efficientnet_b0 size
1225
+ self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
1226
+ self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
1227
+ self.drop = nn.Dropout(p=0.0, inplace=True)
1228
+ self.linear = nn.Linear(c_, c2) # to x(b,c2)
1229
+
1230
+ def forward(self, x):
1231
+ if isinstance(x, list):
1232
+ x = torch.cat(x, 1)
1233
+ return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
yolov9/models/detect/gelan-c.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [64, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, ADown, [256]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, ADown, [512]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, ADown, [512]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 8
42
+ ]
43
+
44
+ # gelan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [512, 256]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 15 (P3/8-small)
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, ADown, [256]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, ADown, [512]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 21 (P5/32-large)
77
+
78
+ # detect
79
+ [[15, 18, 21], 1, DDetect, [nc]], # DDetect(P3, P4, P5)
80
+ ]
yolov9/models/detect/gelan-e.yaml ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ [-1, 1, Silence, []],
17
+
18
+ # conv down
19
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
20
+
21
+ # conv down
22
+ [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
23
+
24
+ # elan-1 block
25
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]], # 3
26
+
27
+ # avg-conv down
28
+ [-1, 1, ADown, [256]], # 4-P3/8
29
+
30
+ # elan-2 block
31
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]], # 5
32
+
33
+ # avg-conv down
34
+ [-1, 1, ADown, [512]], # 6-P4/16
35
+
36
+ # elan-2 block
37
+ [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]], # 7
38
+
39
+ # avg-conv down
40
+ [-1, 1, ADown, [1024]], # 8-P5/32
41
+
42
+ # elan-2 block
43
+ [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]], # 9
44
+
45
+ # routing
46
+ [1, 1, CBLinear, [[64]]], # 10
47
+ [3, 1, CBLinear, [[64, 128]]], # 11
48
+ [5, 1, CBLinear, [[64, 128, 256]]], # 12
49
+ [7, 1, CBLinear, [[64, 128, 256, 512]]], # 13
50
+ [9, 1, CBLinear, [[64, 128, 256, 512, 1024]]], # 14
51
+
52
+ # conv down fuse
53
+ [0, 1, Conv, [64, 3, 2]], # 15-P1/2
54
+ [[10, 11, 12, 13, 14, -1], 1, CBFuse, [[0, 0, 0, 0, 0]]], # 16
55
+
56
+ # conv down fuse
57
+ [-1, 1, Conv, [128, 3, 2]], # 17-P2/4
58
+ [[11, 12, 13, 14, -1], 1, CBFuse, [[1, 1, 1, 1]]], # 18
59
+
60
+ # elan-1 block
61
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 2]], # 19
62
+
63
+ # avg-conv down fuse
64
+ [-1, 1, ADown, [256]], # 20-P3/8
65
+ [[12, 13, 14, -1], 1, CBFuse, [[2, 2, 2]]], # 21
66
+
67
+ # elan-2 block
68
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 2]], # 22
69
+
70
+ # avg-conv down fuse
71
+ [-1, 1, ADown, [512]], # 23-P4/16
72
+ [[13, 14, -1], 1, CBFuse, [[3, 3]]], # 24
73
+
74
+ # elan-2 block
75
+ [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]], # 25
76
+
77
+ # avg-conv down fuse
78
+ [-1, 1, ADown, [1024]], # 26-P5/32
79
+ [[14, -1], 1, CBFuse, [[4]]], # 27
80
+
81
+ # elan-2 block
82
+ [-1, 1, RepNCSPELAN4, [1024, 512, 256, 2]], # 28
83
+ ]
84
+
85
+ # gelan head
86
+ head:
87
+ [
88
+ # elan-spp block
89
+ [28, 1, SPPELAN, [512, 256]], # 29
90
+
91
+ # up-concat merge
92
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
93
+ [[-1, 25], 1, Concat, [1]], # cat backbone P4
94
+
95
+ # elan-2 block
96
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]], # 32
97
+
98
+ # up-concat merge
99
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
100
+ [[-1, 22], 1, Concat, [1]], # cat backbone P3
101
+
102
+ # elan-2 block
103
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 2]], # 35 (P3/8-small)
104
+
105
+ # avg-conv-down merge
106
+ [-1, 1, ADown, [256]],
107
+ [[-1, 32], 1, Concat, [1]], # cat head P4
108
+
109
+ # elan-2 block
110
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 2]], # 38 (P4/16-medium)
111
+
112
+ # avg-conv-down merge
113
+ [-1, 1, ADown, [512]],
114
+ [[-1, 29], 1, Concat, [1]], # cat head P5
115
+
116
+ # elan-2 block
117
+ [-1, 1, RepNCSPELAN4, [512, 1024, 512, 2]], # 41 (P5/32-large)
118
+
119
+ # detect
120
+ [[35, 38, 41], 1, DDetect, [nc]], # Detect(P3, P4, P5)
121
+ ]
yolov9/models/detect/gelan-m.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [32, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [64, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 1]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, AConv, [240]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, AConv, [360]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, AConv, [480]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]], # 8
42
+ ]
43
+
44
+ # elan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [480, 240]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]], # 15
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, AConv, [180]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, AConv, [240]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]], # 21 (P5/32-large)
77
+
78
+ # detect
79
+ [[15, 18, 21], 1, DDetect, [nc]], # DDetect(P3, P4, P5)
80
+ ]
yolov9/models/detect/gelan-s.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [32, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [64, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, ELAN1, [64, 64, 32]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, AConv, [128]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, AConv, [192]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, AConv, [256]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 3]], # 8
42
+ ]
43
+
44
+ # elan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [256, 128]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 15
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, AConv, [96]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, AConv, [128]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 3]], # 21 (P5/32-large)
77
+
78
+ # detect
79
+ [[15, 18, 21], 1, DDetect, [nc]], # DDetect(P3, P4, P5)
80
+ ]
yolov9/models/detect/gelan-t.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [16, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [32, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, ELAN1, [32, 32, 16]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, AConv, [64]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, AConv, [96]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, AConv, [128]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 8
42
+ ]
43
+
44
+ # elan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [128, 64]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]], # 15
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, AConv, [48]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, AConv, [64]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 21 (P5/32-large)
77
+
78
+ # detect
79
+ [[15, 18, 21], 1, DDetect, [nc]], # DDetect(P3, P4, P5)
80
+ ]
yolov9/models/detect/gelan.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [64, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, Conv, [512, 3, 2]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 8
42
+ ]
43
+
44
+ # gelan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [512, 256]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 15 (P3/8-small)
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, Conv, [256, 3, 2]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, Conv, [512, 3, 2]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 21 (P5/32-large)
77
+
78
+ # detect
79
+ [[15, 18, 21], 1, DDetect, [nc]], # Detect(P3, P4, P5)
80
+ ]
yolov9/models/detect/yolov7-af.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv7
2
+
3
+ # Parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1. # model depth multiple
6
+ width_multiple: 1. # layer channel multiple
7
+ anchors: 3
8
+
9
+ # YOLOv7 backbone
10
+ backbone:
11
+ # [from, number, module, args]
12
+ [[-1, 1, Conv, [32, 3, 1]], # 0
13
+
14
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
15
+ [-1, 1, Conv, [64, 3, 1]],
16
+
17
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
18
+ [-1, 1, Conv, [64, 1, 1]],
19
+ [-2, 1, Conv, [64, 1, 1]],
20
+ [-1, 1, Conv, [64, 3, 1]],
21
+ [-1, 1, Conv, [64, 3, 1]],
22
+ [-1, 1, Conv, [64, 3, 1]],
23
+ [-1, 1, Conv, [64, 3, 1]],
24
+ [[-1, -3, -5, -6], 1, Concat, [1]],
25
+ [-1, 1, Conv, [256, 1, 1]], # 11
26
+
27
+ [-1, 1, MP, []],
28
+ [-1, 1, Conv, [128, 1, 1]],
29
+ [-3, 1, Conv, [128, 1, 1]],
30
+ [-1, 1, Conv, [128, 3, 2]],
31
+ [[-1, -3], 1, Concat, [1]], # 16-P3/8
32
+ [-1, 1, Conv, [128, 1, 1]],
33
+ [-2, 1, Conv, [128, 1, 1]],
34
+ [-1, 1, Conv, [128, 3, 1]],
35
+ [-1, 1, Conv, [128, 3, 1]],
36
+ [-1, 1, Conv, [128, 3, 1]],
37
+ [-1, 1, Conv, [128, 3, 1]],
38
+ [[-1, -3, -5, -6], 1, Concat, [1]],
39
+ [-1, 1, Conv, [512, 1, 1]], # 24
40
+
41
+ [-1, 1, MP, []],
42
+ [-1, 1, Conv, [256, 1, 1]],
43
+ [-3, 1, Conv, [256, 1, 1]],
44
+ [-1, 1, Conv, [256, 3, 2]],
45
+ [[-1, -3], 1, Concat, [1]], # 29-P4/16
46
+ [-1, 1, Conv, [256, 1, 1]],
47
+ [-2, 1, Conv, [256, 1, 1]],
48
+ [-1, 1, Conv, [256, 3, 1]],
49
+ [-1, 1, Conv, [256, 3, 1]],
50
+ [-1, 1, Conv, [256, 3, 1]],
51
+ [-1, 1, Conv, [256, 3, 1]],
52
+ [[-1, -3, -5, -6], 1, Concat, [1]],
53
+ [-1, 1, Conv, [1024, 1, 1]], # 37
54
+
55
+ [-1, 1, MP, []],
56
+ [-1, 1, Conv, [512, 1, 1]],
57
+ [-3, 1, Conv, [512, 1, 1]],
58
+ [-1, 1, Conv, [512, 3, 2]],
59
+ [[-1, -3], 1, Concat, [1]], # 42-P5/32
60
+ [-1, 1, Conv, [256, 1, 1]],
61
+ [-2, 1, Conv, [256, 1, 1]],
62
+ [-1, 1, Conv, [256, 3, 1]],
63
+ [-1, 1, Conv, [256, 3, 1]],
64
+ [-1, 1, Conv, [256, 3, 1]],
65
+ [-1, 1, Conv, [256, 3, 1]],
66
+ [[-1, -3, -5, -6], 1, Concat, [1]],
67
+ [-1, 1, Conv, [1024, 1, 1]], # 50
68
+ ]
69
+
70
+ # yolov7 head
71
+ head:
72
+ [[-1, 1, SPPCSPC, [512]], # 51
73
+
74
+ [-1, 1, Conv, [256, 1, 1]],
75
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
76
+ [37, 1, Conv, [256, 1, 1]], # route backbone P4
77
+ [[-1, -2], 1, Concat, [1]],
78
+
79
+ [-1, 1, Conv, [256, 1, 1]],
80
+ [-2, 1, Conv, [256, 1, 1]],
81
+ [-1, 1, Conv, [128, 3, 1]],
82
+ [-1, 1, Conv, [128, 3, 1]],
83
+ [-1, 1, Conv, [128, 3, 1]],
84
+ [-1, 1, Conv, [128, 3, 1]],
85
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
86
+ [-1, 1, Conv, [256, 1, 1]], # 63
87
+
88
+ [-1, 1, Conv, [128, 1, 1]],
89
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
90
+ [24, 1, Conv, [128, 1, 1]], # route backbone P3
91
+ [[-1, -2], 1, Concat, [1]],
92
+
93
+ [-1, 1, Conv, [128, 1, 1]],
94
+ [-2, 1, Conv, [128, 1, 1]],
95
+ [-1, 1, Conv, [64, 3, 1]],
96
+ [-1, 1, Conv, [64, 3, 1]],
97
+ [-1, 1, Conv, [64, 3, 1]],
98
+ [-1, 1, Conv, [64, 3, 1]],
99
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
100
+ [-1, 1, Conv, [128, 1, 1]], # 75
101
+
102
+ [-1, 1, MP, []],
103
+ [-1, 1, Conv, [128, 1, 1]],
104
+ [-3, 1, Conv, [128, 1, 1]],
105
+ [-1, 1, Conv, [128, 3, 2]],
106
+ [[-1, -3, 63], 1, Concat, [1]],
107
+
108
+ [-1, 1, Conv, [256, 1, 1]],
109
+ [-2, 1, Conv, [256, 1, 1]],
110
+ [-1, 1, Conv, [128, 3, 1]],
111
+ [-1, 1, Conv, [128, 3, 1]],
112
+ [-1, 1, Conv, [128, 3, 1]],
113
+ [-1, 1, Conv, [128, 3, 1]],
114
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
115
+ [-1, 1, Conv, [256, 1, 1]], # 88
116
+
117
+ [-1, 1, MP, []],
118
+ [-1, 1, Conv, [256, 1, 1]],
119
+ [-3, 1, Conv, [256, 1, 1]],
120
+ [-1, 1, Conv, [256, 3, 2]],
121
+ [[-1, -3, 51], 1, Concat, [1]],
122
+
123
+ [-1, 1, Conv, [512, 1, 1]],
124
+ [-2, 1, Conv, [512, 1, 1]],
125
+ [-1, 1, Conv, [256, 3, 1]],
126
+ [-1, 1, Conv, [256, 3, 1]],
127
+ [-1, 1, Conv, [256, 3, 1]],
128
+ [-1, 1, Conv, [256, 3, 1]],
129
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
130
+ [-1, 1, Conv, [512, 1, 1]], # 101
131
+
132
+ [75, 1, Conv, [256, 3, 1]],
133
+ [88, 1, Conv, [512, 3, 1]],
134
+ [101, 1, Conv, [1024, 3, 1]],
135
+
136
+ [[102, 103, 104], 1, Detect, [nc]], # Detect(P3, P4, P5)
137
+ ]
yolov9/models/detect/yolov9-cf.yaml ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # YOLOv9 backbone
14
+ backbone:
15
+ [
16
+ [-1, 1, Silence, []],
17
+
18
+ # conv down
19
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
20
+
21
+ # conv down
22
+ [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
23
+
24
+ # elan-1 block
25
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 3
26
+
27
+ # avg-conv down
28
+ [-1, 1, ADown, [256]], # 4-P3/8
29
+
30
+ # elan-2 block
31
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 5
32
+
33
+ # avg-conv down
34
+ [-1, 1, ADown, [512]], # 6-P4/16
35
+
36
+ # elan-2 block
37
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 7
38
+
39
+ # avg-conv down
40
+ [-1, 1, ADown, [512]], # 8-P5/32
41
+
42
+ # elan-2 block
43
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 9
44
+ ]
45
+
46
+ # YOLOv9 head
47
+ head:
48
+ [
49
+ # elan-spp block
50
+ [-1, 1, SPPELAN, [512, 256]], # 10
51
+
52
+ # up-concat merge
53
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
54
+ [[-1, 7], 1, Concat, [1]], # cat backbone P4
55
+
56
+ # elan-2 block
57
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 13
58
+
59
+ # up-concat merge
60
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
61
+ [[-1, 5], 1, Concat, [1]], # cat backbone P3
62
+
63
+ # elan-2 block
64
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 16 (P3/8-small)
65
+
66
+ # avg-conv-down merge
67
+ [-1, 1, ADown, [256]],
68
+ [[-1, 13], 1, Concat, [1]], # cat head P4
69
+
70
+ # elan-2 block
71
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 19 (P4/16-medium)
72
+
73
+ # avg-conv-down merge
74
+ [-1, 1, ADown, [512]],
75
+ [[-1, 10], 1, Concat, [1]], # cat head P5
76
+
77
+ # elan-2 block
78
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 22 (P5/32-large)
79
+
80
+
81
+ # multi-level reversible auxiliary branch
82
+
83
+ # routing
84
+ [5, 1, CBLinear, [[256]]], # 23
85
+ [7, 1, CBLinear, [[256, 512]]], # 24
86
+ [9, 1, CBLinear, [[256, 512, 512]]], # 25
87
+
88
+ # conv down
89
+ [0, 1, Conv, [64, 3, 2]], # 26-P1/2
90
+
91
+ # conv down
92
+ [-1, 1, Conv, [128, 3, 2]], # 27-P2/4
93
+
94
+ # elan-1 block
95
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 28
96
+
97
+ # avg-conv down fuse
98
+ [-1, 1, ADown, [256]], # 29-P3/8
99
+ [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30
100
+
101
+ # elan-2 block
102
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 31
103
+
104
+ # avg-conv down fuse
105
+ [-1, 1, ADown, [512]], # 32-P4/16
106
+ [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33
107
+
108
+ # elan-2 block
109
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 34
110
+
111
+ # avg-conv down fuse
112
+ [-1, 1, ADown, [512]], # 35-P5/32
113
+ [[25, -1], 1, CBFuse, [[2]]], # 36
114
+
115
+ # elan-2 block
116
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 37
117
+
118
+
119
+
120
+ # detection head
121
+
122
+ # detect
123
+ [[31, 34, 37, 16, 19, 22, 16, 19, 22], 1, TripleDDetect, [nc]], # TripleDDetect(A3, A4, A5, P3, P4, P5, P3, P4, P5) Auxiliary/Coarse(NMS-based)/Fine(NMS-free)
124
+ ]
yolov9/models/detect/yolov9-m.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ [-1, 1, Silence, []],
17
+
18
+ # conv down
19
+ [-1, 1, Conv, [32, 3, 2]], # 1-P1/2
20
+
21
+ # conv down
22
+ [-1, 1, Conv, [64, 3, 2]], # 2-P2/4
23
+
24
+ # elan-1 block
25
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 1]], # 3
26
+
27
+ # avg-conv down
28
+ [-1, 1, AConv, [240]], # 4-P3/8
29
+
30
+ # elan-2 block
31
+ [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]], # 5
32
+
33
+ # avg-conv down
34
+ [-1, 1, AConv, [360]], # 6-P4/16
35
+
36
+ # elan-2 block
37
+ [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]], # 7
38
+
39
+ # avg-conv down
40
+ [-1, 1, AConv, [480]], # 8-P5/32
41
+
42
+ # elan-2 block
43
+ [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]], # 9
44
+ ]
45
+
46
+ # elan head
47
+ head:
48
+ [
49
+ # elan-spp block
50
+ [-1, 1, SPPELAN, [480, 240]], # 10
51
+
52
+ # up-concat merge
53
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
54
+ [[-1, 7], 1, Concat, [1]], # cat backbone P4
55
+
56
+ # elan-2 block
57
+ [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]], # 13
58
+
59
+ # up-concat merge
60
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
61
+ [[-1, 5], 1, Concat, [1]], # cat backbone P3
62
+
63
+ # elan-2 block
64
+ [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]], # 16
65
+
66
+ # avg-conv-down merge
67
+ [-1, 1, AConv, [180]],
68
+ [[-1, 13], 1, Concat, [1]], # cat head P4
69
+
70
+ # elan-2 block
71
+ [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]], # 19 (P4/16-medium)
72
+
73
+ # avg-conv-down merge
74
+ [-1, 1, AConv, [240]],
75
+ [[-1, 10], 1, Concat, [1]], # cat head P5
76
+
77
+ # elan-2 block
78
+ [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]], # 22 (P5/32-large)
79
+
80
+ # routing
81
+ [5, 1, CBLinear, [[240]]], # 23
82
+ [7, 1, CBLinear, [[240, 360]]], # 24
83
+ [9, 1, CBLinear, [[240, 360, 480]]], # 25
84
+
85
+ # conv down
86
+ [0, 1, Conv, [32, 3, 2]], # 26-P1/2
87
+
88
+ # conv down
89
+ [-1, 1, Conv, [64, 3, 2]], # 27-P2/4
90
+
91
+ # elan-1 block
92
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 1]], # 28
93
+
94
+ # avg-conv down
95
+ [-1, 1, AConv, [240]], # 29-P3/8
96
+ [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30
97
+
98
+ # elan-2 block
99
+ [-1, 1, RepNCSPELAN4, [240, 240, 120, 1]], # 31
100
+
101
+ # avg-conv down
102
+ [-1, 1, AConv, [360]], # 32-P4/16
103
+ [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33
104
+
105
+ # elan-2 block
106
+ [-1, 1, RepNCSPELAN4, [360, 360, 180, 1]], # 34
107
+
108
+ # avg-conv down
109
+ [-1, 1, AConv, [480]], # 35-P5/32
110
+ [[25, -1], 1, CBFuse, [[2]]], # 36
111
+
112
+ # elan-2 block
113
+ [-1, 1, RepNCSPELAN4, [480, 480, 240, 1]], # 37
114
+
115
+ # detect
116
+ [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]], # Detect(P3, P4, P5)
117
+ ]
yolov9/models/detect/yolov9-s.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [32, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [64, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, ELAN1, [64, 64, 32]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, AConv, [128]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, AConv, [192]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, AConv, [256]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 3]], # 8
42
+ ]
43
+
44
+ # elan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [256, 128]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 15
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, AConv, [96]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, AConv, [128]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 3]], # 21 (P5/32-large)
77
+
78
+ # elan-spp block
79
+ [8, 1, SPPELAN, [256, 128]], # 22
80
+
81
+ # up-concat merge
82
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
83
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
84
+
85
+ # elan-2 block
86
+ [-1, 1, RepNCSPELAN4, [192, 192, 96, 3]], # 25
87
+
88
+ # up-concat merge
89
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
90
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
91
+
92
+ # elan-2 block
93
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 28
94
+
95
+ # detect
96
+ [[28, 25, 22, 15, 18, 21], 1, DualDDetect, [nc]], # Detect(P3, P4, P5)
97
+ ]
yolov9/models/detect/yolov9-t.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [16, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [32, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, ELAN1, [32, 32, 16]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, AConv, [64]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, AConv, [96]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, AConv, [128]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 8
42
+ ]
43
+
44
+ # elan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [128, 64]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]], # 15
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, AConv, [48]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, AConv, [64]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [128, 128, 64, 3]], # 21 (P5/32-large)
77
+
78
+ # elan-spp block
79
+ [8, 1, SPPELAN, [128, 64]], # 22
80
+
81
+ # up-concat merge
82
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
83
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
84
+
85
+ # elan-2 block
86
+ [-1, 1, RepNCSPELAN4, [96, 96, 48, 3]], # 25
87
+
88
+ # up-concat merge
89
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
90
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
91
+
92
+ # elan-2 block
93
+ [-1, 1, RepNCSPELAN4, [64, 64, 32, 3]], # 28
94
+
95
+ # detect
96
+ [[28, 25, 22, 15, 18, 21], 1, DualDDetect, [nc]], # Detect(P3, P4, P5)
97
+ ]
yolov9/models/detect/yolov9.yaml ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # YOLOv9 backbone
14
+ backbone:
15
+ [
16
+ [-1, 1, Silence, []],
17
+
18
+ # conv down
19
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
20
+
21
+ # conv down
22
+ [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
23
+
24
+ # elan-1 block
25
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 3
26
+
27
+ # conv down
28
+ [-1, 1, Conv, [256, 3, 2]], # 4-P3/8
29
+
30
+ # elan-2 block
31
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 5
32
+
33
+ # conv down
34
+ [-1, 1, Conv, [512, 3, 2]], # 6-P4/16
35
+
36
+ # elan-2 block
37
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 7
38
+
39
+ # conv down
40
+ [-1, 1, Conv, [512, 3, 2]], # 8-P5/32
41
+
42
+ # elan-2 block
43
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 9
44
+ ]
45
+
46
+ # YOLOv9 head
47
+ head:
48
+ [
49
+ # elan-spp block
50
+ [-1, 1, SPPELAN, [512, 256]], # 10
51
+
52
+ # up-concat merge
53
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
54
+ [[-1, 7], 1, Concat, [1]], # cat backbone P4
55
+
56
+ # elan-2 block
57
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 13
58
+
59
+ # up-concat merge
60
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
61
+ [[-1, 5], 1, Concat, [1]], # cat backbone P3
62
+
63
+ # elan-2 block
64
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 16 (P3/8-small)
65
+
66
+ # conv-down merge
67
+ [-1, 1, Conv, [256, 3, 2]],
68
+ [[-1, 13], 1, Concat, [1]], # cat head P4
69
+
70
+ # elan-2 block
71
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 19 (P4/16-medium)
72
+
73
+ # conv-down merge
74
+ [-1, 1, Conv, [512, 3, 2]],
75
+ [[-1, 10], 1, Concat, [1]], # cat head P5
76
+
77
+ # elan-2 block
78
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 22 (P5/32-large)
79
+
80
+ # routing
81
+ [5, 1, CBLinear, [[256]]], # 23
82
+ [7, 1, CBLinear, [[256, 512]]], # 24
83
+ [9, 1, CBLinear, [[256, 512, 512]]], # 25
84
+
85
+ # conv down
86
+ [0, 1, Conv, [64, 3, 2]], # 26-P1/2
87
+
88
+ # conv down
89
+ [-1, 1, Conv, [128, 3, 2]], # 27-P2/4
90
+
91
+ # elan-1 block
92
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 28
93
+
94
+ # conv down fuse
95
+ [-1, 1, Conv, [256, 3, 2]], # 29-P3/8
96
+ [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30
97
+
98
+ # elan-2 block
99
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 31
100
+
101
+ # conv down fuse
102
+ [-1, 1, Conv, [512, 3, 2]], # 32-P4/16
103
+ [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33
104
+
105
+ # elan-2 block
106
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 34
107
+
108
+ # conv down fuse
109
+ [-1, 1, Conv, [512, 3, 2]], # 35-P5/32
110
+ [[25, -1], 1, CBFuse, [[2]]], # 36
111
+
112
+ # elan-2 block
113
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 37
114
+
115
+ # detect
116
+ [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]], # DualDDetect(A3, A4, A5, P3, P4, P5)
117
+ ]
yolov9/models/experimental.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from utils.downloads import attempt_download
8
+
9
+
10
+ class Sum(nn.Module):
11
+ # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070
12
+ def __init__(self, n, weight=False): # n: number of inputs
13
+ super().__init__()
14
+ self.weight = weight # apply weights boolean
15
+ self.iter = range(n - 1) # iter object
16
+ if weight:
17
+ self.w = nn.Parameter(-torch.arange(1.0, n) / 2, requires_grad=True) # layer weights
18
+
19
+ def forward(self, x):
20
+ y = x[0] # no weight
21
+ if self.weight:
22
+ w = torch.sigmoid(self.w) * 2
23
+ for i in self.iter:
24
+ y = y + x[i + 1] * w[i]
25
+ else:
26
+ for i in self.iter:
27
+ y = y + x[i + 1]
28
+ return y
29
+
30
+
31
+ class MixConv2d(nn.Module):
32
+ # Mixed Depth-wise Conv https://arxiv.org/abs/1907.09595
33
+ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): # ch_in, ch_out, kernel, stride, ch_strategy
34
+ super().__init__()
35
+ n = len(k) # number of convolutions
36
+ if equal_ch: # equal c_ per group
37
+ i = torch.linspace(0, n - 1E-6, c2).floor() # c2 indices
38
+ c_ = [(i == g).sum() for g in range(n)] # intermediate channels
39
+ else: # equal weight.numel() per group
40
+ b = [c2] + [0] * n
41
+ a = np.eye(n + 1, n, k=-1)
42
+ a -= np.roll(a, 1, axis=1)
43
+ a *= np.array(k) ** 2
44
+ a[0] = 1
45
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
46
+
47
+ self.m = nn.ModuleList([
48
+ nn.Conv2d(c1, int(c_), k, s, k // 2, groups=math.gcd(c1, int(c_)), bias=False) for k, c_ in zip(k, c_)])
49
+ self.bn = nn.BatchNorm2d(c2)
50
+ self.act = nn.SiLU()
51
+
52
+ def forward(self, x):
53
+ return self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
54
+
55
+
56
+ class Ensemble(nn.ModuleList):
57
+ # Ensemble of models
58
+ def __init__(self):
59
+ super().__init__()
60
+
61
+ def forward(self, x, augment=False, profile=False, visualize=False):
62
+ y = [module(x, augment, profile, visualize)[0] for module in self]
63
+ # y = torch.stack(y).max(0)[0] # max ensemble
64
+ # y = torch.stack(y).mean(0) # mean ensemble
65
+ y = torch.cat(y, 1) # nms ensemble
66
+ return y, None # inference, train output
67
+
68
+
69
+ class ORT_NMS(torch.autograd.Function):
70
+ '''ONNX-Runtime NMS operation'''
71
+ @staticmethod
72
+ def forward(ctx,
73
+ boxes,
74
+ scores,
75
+ max_output_boxes_per_class=torch.tensor([100]),
76
+ iou_threshold=torch.tensor([0.45]),
77
+ score_threshold=torch.tensor([0.25])):
78
+ device = boxes.device
79
+ batch = scores.shape[0]
80
+ num_det = random.randint(0, 100)
81
+ batches = torch.randint(0, batch, (num_det,)).sort()[0].to(device)
82
+ idxs = torch.arange(100, 100 + num_det).to(device)
83
+ zeros = torch.zeros((num_det,), dtype=torch.int64).to(device)
84
+ selected_indices = torch.cat([batches[None], zeros[None], idxs[None]], 0).T.contiguous()
85
+ selected_indices = selected_indices.to(torch.int64)
86
+ return selected_indices
87
+
88
+ @staticmethod
89
+ def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
90
+ return g.op("NonMaxSuppression", boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold)
91
+
92
+
93
+ class TRT_NMS(torch.autograd.Function):
94
+ '''TensorRT NMS operation'''
95
+ @staticmethod
96
+ def forward(
97
+ ctx,
98
+ boxes,
99
+ scores,
100
+ background_class=-1,
101
+ box_coding=1,
102
+ iou_threshold=0.45,
103
+ max_output_boxes=100,
104
+ plugin_version="1",
105
+ score_activation=0,
106
+ score_threshold=0.25,
107
+ ):
108
+
109
+ batch_size, num_boxes, num_classes = scores.shape
110
+ num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
111
+ det_boxes = torch.randn(batch_size, max_output_boxes, 4)
112
+ det_scores = torch.randn(batch_size, max_output_boxes)
113
+ det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
114
+ return num_det, det_boxes, det_scores, det_classes
115
+
116
+ @staticmethod
117
+ def symbolic(g,
118
+ boxes,
119
+ scores,
120
+ background_class=-1,
121
+ box_coding=1,
122
+ iou_threshold=0.45,
123
+ max_output_boxes=100,
124
+ plugin_version="1",
125
+ score_activation=0,
126
+ score_threshold=0.25):
127
+ out = g.op("TRT::EfficientNMS_TRT",
128
+ boxes,
129
+ scores,
130
+ background_class_i=background_class,
131
+ box_coding_i=box_coding,
132
+ iou_threshold_f=iou_threshold,
133
+ max_output_boxes_i=max_output_boxes,
134
+ plugin_version_s=plugin_version,
135
+ score_activation_i=score_activation,
136
+ score_threshold_f=score_threshold,
137
+ outputs=4)
138
+ nums, boxes, scores, classes = out
139
+ return nums, boxes, scores, classes
140
+
141
+
142
+ class ONNX_ORT(nn.Module):
143
+ '''onnx module with ONNX-Runtime NMS operation.'''
144
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80):
145
+ super().__init__()
146
+ self.device = device if device else torch.device("cpu")
147
+ self.max_obj = torch.tensor([max_obj]).to(device)
148
+ self.iou_threshold = torch.tensor([iou_thres]).to(device)
149
+ self.score_threshold = torch.tensor([score_thres]).to(device)
150
+ self.max_wh = max_wh # if max_wh != 0 : non-agnostic else : agnostic
151
+ self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
152
+ dtype=torch.float32,
153
+ device=self.device)
154
+ self.n_classes=n_classes
155
+
156
+ def forward(self, x):
157
+ ## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
158
+ ## thanks https://github.com/thaitc-hust
159
+ if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
160
+ x = x[1]
161
+ x = x.permute(0, 2, 1)
162
+ bboxes_x = x[..., 0:1]
163
+ bboxes_y = x[..., 1:2]
164
+ bboxes_w = x[..., 2:3]
165
+ bboxes_h = x[..., 3:4]
166
+ bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
167
+ bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
168
+ obj_conf = x[..., 4:]
169
+ scores = obj_conf
170
+ bboxes @= self.convert_matrix
171
+ max_score, category_id = scores.max(2, keepdim=True)
172
+ dis = category_id.float() * self.max_wh
173
+ nmsbox = bboxes + dis
174
+ max_score_tp = max_score.transpose(1, 2).contiguous()
175
+ selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
176
+ X, Y = selected_indices[:, 0], selected_indices[:, 2]
177
+ selected_boxes = bboxes[X, Y, :]
178
+ selected_categories = category_id[X, Y, :].float()
179
+ selected_scores = max_score[X, Y, :]
180
+ X = X.unsqueeze(1).float()
181
+ return torch.cat([X, selected_boxes, selected_categories, selected_scores], 1)
182
+
183
+
184
+ class ONNX_TRT(nn.Module):
185
+ '''onnx module with TensorRT NMS operation.'''
186
+ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
187
+ super().__init__()
188
+ assert max_wh is None
189
+ self.device = device if device else torch.device('cpu')
190
+ self.background_class = -1,
191
+ self.box_coding = 1,
192
+ self.iou_threshold = iou_thres
193
+ self.max_obj = max_obj
194
+ self.plugin_version = '1'
195
+ self.score_activation = 0
196
+ self.score_threshold = score_thres
197
+ self.n_classes=n_classes
198
+
199
+ def forward(self, x):
200
+ ## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
201
+ ## thanks https://github.com/thaitc-hust
202
+ if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
203
+ x = x[1]
204
+ x = x.permute(0, 2, 1)
205
+ bboxes_x = x[..., 0:1]
206
+ bboxes_y = x[..., 1:2]
207
+ bboxes_w = x[..., 2:3]
208
+ bboxes_h = x[..., 3:4]
209
+ bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
210
+ bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
211
+ obj_conf = x[..., 4:]
212
+ scores = obj_conf
213
+ num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(bboxes, scores, self.background_class, self.box_coding,
214
+ self.iou_threshold, self.max_obj,
215
+ self.plugin_version, self.score_activation,
216
+ self.score_threshold)
217
+ return num_det, det_boxes, det_scores, det_classes
218
+
219
+ class End2End(nn.Module):
220
+ '''export onnx or tensorrt model with NMS operation.'''
221
+ def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
222
+ super().__init__()
223
+ device = device if device else torch.device('cpu')
224
+ assert isinstance(max_wh,(int)) or max_wh is None
225
+ self.model = model.to(device)
226
+ self.model.model[-1].end2end = True
227
+ self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
228
+ self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
229
+ self.end2end.eval()
230
+
231
+ def forward(self, x):
232
+ x = self.model(x)
233
+ x = self.end2end(x)
234
+ return x
235
+
236
+
237
+ def attempt_load(weights, device=None, inplace=True, fuse=True):
238
+ # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
239
+ from models.yolo import Detect, Model
240
+
241
+ model = Ensemble()
242
+ for w in weights if isinstance(weights, list) else [weights]:
243
+ ckpt = torch.load(attempt_download(w), map_location='cpu') # load
244
+ ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
245
+
246
+ # Model compatibility updates
247
+ if not hasattr(ckpt, 'stride'):
248
+ ckpt.stride = torch.tensor([32.])
249
+ if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
250
+ ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
251
+
252
+ model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
253
+
254
+ # Module compatibility updates
255
+ for m in model.modules():
256
+ t = type(m)
257
+ if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
258
+ m.inplace = inplace # torch 1.7.0 compatibility
259
+ # if t is Detect and not isinstance(m.anchor_grid, list):
260
+ # delattr(m, 'anchor_grid')
261
+ # setattr(m, 'anchor_grid', [torch.zeros(1)] * m.nl)
262
+ elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
263
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
264
+
265
+ # Return model
266
+ if len(model) == 1:
267
+ return model[-1]
268
+
269
+ # Return detection ensemble
270
+ print(f'Ensemble created with {weights}\n')
271
+ for k in 'names', 'nc', 'yaml':
272
+ setattr(model, k, getattr(model[0], k))
273
+ model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
274
+ assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
275
+ return model
yolov9/models/hub/anchors.yaml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv3 & YOLOv5
2
+ # Default anchors for COCO data
3
+
4
+
5
+ # P5 -------------------------------------------------------------------------------------------------------------------
6
+ # P5-640:
7
+ anchors_p5_640:
8
+ - [10,13, 16,30, 33,23] # P3/8
9
+ - [30,61, 62,45, 59,119] # P4/16
10
+ - [116,90, 156,198, 373,326] # P5/32
11
+
12
+
13
+ # P6 -------------------------------------------------------------------------------------------------------------------
14
+ # P6-640: thr=0.25: 0.9964 BPR, 5.54 anchors past thr, n=12, img_size=640, metric_all=0.281/0.716-mean/best, past_thr=0.469-mean: 9,11, 21,19, 17,41, 43,32, 39,70, 86,64, 65,131, 134,130, 120,265, 282,180, 247,354, 512,387
15
+ anchors_p6_640:
16
+ - [9,11, 21,19, 17,41] # P3/8
17
+ - [43,32, 39,70, 86,64] # P4/16
18
+ - [65,131, 134,130, 120,265] # P5/32
19
+ - [282,180, 247,354, 512,387] # P6/64
20
+
21
+ # P6-1280: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1280, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 19,27, 44,40, 38,94, 96,68, 86,152, 180,137, 140,301, 303,264, 238,542, 436,615, 739,380, 925,792
22
+ anchors_p6_1280:
23
+ - [19,27, 44,40, 38,94] # P3/8
24
+ - [96,68, 86,152, 180,137] # P4/16
25
+ - [140,301, 303,264, 238,542] # P5/32
26
+ - [436,615, 739,380, 925,792] # P6/64
27
+
28
+ # P6-1920: thr=0.25: 0.9950 BPR, 5.55 anchors past thr, n=12, img_size=1920, metric_all=0.281/0.714-mean/best, past_thr=0.468-mean: 28,41, 67,59, 57,141, 144,103, 129,227, 270,205, 209,452, 455,396, 358,812, 653,922, 1109,570, 1387,1187
29
+ anchors_p6_1920:
30
+ - [28,41, 67,59, 57,141] # P3/8
31
+ - [144,103, 129,227, 270,205] # P4/16
32
+ - [209,452, 455,396, 358,812] # P5/32
33
+ - [653,922, 1109,570, 1387,1187] # P6/64
34
+
35
+
36
+ # P7 -------------------------------------------------------------------------------------------------------------------
37
+ # P7-640: thr=0.25: 0.9962 BPR, 6.76 anchors past thr, n=15, img_size=640, metric_all=0.275/0.733-mean/best, past_thr=0.466-mean: 11,11, 13,30, 29,20, 30,46, 61,38, 39,92, 78,80, 146,66, 79,163, 149,150, 321,143, 157,303, 257,402, 359,290, 524,372
38
+ anchors_p7_640:
39
+ - [11,11, 13,30, 29,20] # P3/8
40
+ - [30,46, 61,38, 39,92] # P4/16
41
+ - [78,80, 146,66, 79,163] # P5/32
42
+ - [149,150, 321,143, 157,303] # P6/64
43
+ - [257,402, 359,290, 524,372] # P7/128
44
+
45
+ # P7-1280: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1280, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 19,22, 54,36, 32,77, 70,83, 138,71, 75,173, 165,159, 148,334, 375,151, 334,317, 251,626, 499,474, 750,326, 534,814, 1079,818
46
+ anchors_p7_1280:
47
+ - [19,22, 54,36, 32,77] # P3/8
48
+ - [70,83, 138,71, 75,173] # P4/16
49
+ - [165,159, 148,334, 375,151] # P5/32
50
+ - [334,317, 251,626, 499,474] # P6/64
51
+ - [750,326, 534,814, 1079,818] # P7/128
52
+
53
+ # P7-1920: thr=0.25: 0.9968 BPR, 6.71 anchors past thr, n=15, img_size=1920, metric_all=0.273/0.732-mean/best, past_thr=0.463-mean: 29,34, 81,55, 47,115, 105,124, 207,107, 113,259, 247,238, 222,500, 563,227, 501,476, 376,939, 749,711, 1126,489, 801,1222, 1618,1227
54
+ anchors_p7_1920:
55
+ - [29,34, 81,55, 47,115] # P3/8
56
+ - [105,124, 207,107, 113,259] # P4/16
57
+ - [247,238, 222,500, 563,227] # P5/32
58
+ - [501,476, 376,939, 749,711] # P6/64
59
+ - [1126,489, 801,1222, 1618,1227] # P7/128
yolov9/models/hub/yolov3-spp.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv3
2
+
3
+ # Parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ anchors:
8
+ - [10,13, 16,30, 33,23] # P3/8
9
+ - [30,61, 62,45, 59,119] # P4/16
10
+ - [116,90, 156,198, 373,326] # P5/32
11
+
12
+ # darknet53 backbone
13
+ backbone:
14
+ # [from, number, module, args]
15
+ [[-1, 1, Conv, [32, 3, 1]], # 0
16
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
17
+ [-1, 1, Bottleneck, [64]],
18
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
19
+ [-1, 2, Bottleneck, [128]],
20
+ [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
21
+ [-1, 8, Bottleneck, [256]],
22
+ [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
23
+ [-1, 8, Bottleneck, [512]],
24
+ [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
25
+ [-1, 4, Bottleneck, [1024]], # 10
26
+ ]
27
+
28
+ # YOLOv3-SPP head
29
+ head:
30
+ [[-1, 1, Bottleneck, [1024, False]],
31
+ [-1, 1, SPP, [512, [5, 9, 13]]],
32
+ [-1, 1, Conv, [1024, 3, 1]],
33
+ [-1, 1, Conv, [512, 1, 1]],
34
+ [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
35
+
36
+ [-2, 1, Conv, [256, 1, 1]],
37
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
38
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
39
+ [-1, 1, Bottleneck, [512, False]],
40
+ [-1, 1, Bottleneck, [512, False]],
41
+ [-1, 1, Conv, [256, 1, 1]],
42
+ [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
43
+
44
+ [-2, 1, Conv, [128, 1, 1]],
45
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
46
+ [[-1, 6], 1, Concat, [1]], # cat backbone P3
47
+ [-1, 1, Bottleneck, [256, False]],
48
+ [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
49
+
50
+ [[27, 22, 15], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
51
+ ]
yolov9/models/hub/yolov3-tiny.yaml ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv3
2
+
3
+ # Parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ anchors:
8
+ - [10,14, 23,27, 37,58] # P4/16
9
+ - [81,82, 135,169, 344,319] # P5/32
10
+
11
+ # YOLOv3-tiny backbone
12
+ backbone:
13
+ # [from, number, module, args]
14
+ [[-1, 1, Conv, [16, 3, 1]], # 0
15
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 1-P1/2
16
+ [-1, 1, Conv, [32, 3, 1]],
17
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 3-P2/4
18
+ [-1, 1, Conv, [64, 3, 1]],
19
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 5-P3/8
20
+ [-1, 1, Conv, [128, 3, 1]],
21
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 7-P4/16
22
+ [-1, 1, Conv, [256, 3, 1]],
23
+ [-1, 1, nn.MaxPool2d, [2, 2, 0]], # 9-P5/32
24
+ [-1, 1, Conv, [512, 3, 1]],
25
+ [-1, 1, nn.ZeroPad2d, [[0, 1, 0, 1]]], # 11
26
+ [-1, 1, nn.MaxPool2d, [2, 1, 0]], # 12
27
+ ]
28
+
29
+ # YOLOv3-tiny head
30
+ head:
31
+ [[-1, 1, Conv, [1024, 3, 1]],
32
+ [-1, 1, Conv, [256, 1, 1]],
33
+ [-1, 1, Conv, [512, 3, 1]], # 15 (P5/32-large)
34
+
35
+ [-2, 1, Conv, [128, 1, 1]],
36
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
37
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
38
+ [-1, 1, Conv, [256, 3, 1]], # 19 (P4/16-medium)
39
+
40
+ [[19, 15], 1, Detect, [nc, anchors]], # Detect(P4, P5)
41
+ ]
yolov9/models/hub/yolov3.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv3
2
+
3
+ # Parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ anchors:
8
+ - [10,13, 16,30, 33,23] # P3/8
9
+ - [30,61, 62,45, 59,119] # P4/16
10
+ - [116,90, 156,198, 373,326] # P5/32
11
+
12
+ # darknet53 backbone
13
+ backbone:
14
+ # [from, number, module, args]
15
+ [[-1, 1, Conv, [32, 3, 1]], # 0
16
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
17
+ [-1, 1, Bottleneck, [64]],
18
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
19
+ [-1, 2, Bottleneck, [128]],
20
+ [-1, 1, Conv, [256, 3, 2]], # 5-P3/8
21
+ [-1, 8, Bottleneck, [256]],
22
+ [-1, 1, Conv, [512, 3, 2]], # 7-P4/16
23
+ [-1, 8, Bottleneck, [512]],
24
+ [-1, 1, Conv, [1024, 3, 2]], # 9-P5/32
25
+ [-1, 4, Bottleneck, [1024]], # 10
26
+ ]
27
+
28
+ # YOLOv3 head
29
+ head:
30
+ [[-1, 1, Bottleneck, [1024, False]],
31
+ [-1, 1, Conv, [512, 1, 1]],
32
+ [-1, 1, Conv, [1024, 3, 1]],
33
+ [-1, 1, Conv, [512, 1, 1]],
34
+ [-1, 1, Conv, [1024, 3, 1]], # 15 (P5/32-large)
35
+
36
+ [-2, 1, Conv, [256, 1, 1]],
37
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
38
+ [[-1, 8], 1, Concat, [1]], # cat backbone P4
39
+ [-1, 1, Bottleneck, [512, False]],
40
+ [-1, 1, Bottleneck, [512, False]],
41
+ [-1, 1, Conv, [256, 1, 1]],
42
+ [-1, 1, Conv, [512, 3, 1]], # 22 (P4/16-medium)
43
+
44
+ [-2, 1, Conv, [128, 1, 1]],
45
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
46
+ [[-1, 6], 1, Concat, [1]], # cat backbone P3
47
+ [-1, 1, Bottleneck, [256, False]],
48
+ [-1, 2, Bottleneck, [256, False]], # 27 (P3/8-small)
49
+
50
+ [[27, 22, 15], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
51
+ ]
yolov9/models/panoptic/gelan-c-pan.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [64, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, ADown, [256]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, ADown, [512]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, ADown, [512]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 8
42
+ ]
43
+
44
+ # gelan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [512, 256]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 15 (P3/8-small)
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, ADown, [256]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, ADown, [512]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 21 (P5/32-large)
77
+
78
+ # panoptic
79
+ [[15, 18, 21], 1, Panoptic, [nc, 93, 32, 256]], # Panoptic(P3, P4, P5)
80
+ ]
yolov9/models/panoptic/yolov7-af-pan.yaml ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv7
2
+
3
+ # Parameters
4
+ nc: 80 # number of classes
5
+ sem_nc: 93 # number of stuff classes
6
+ depth_multiple: 1.0 # model depth multiple
7
+ width_multiple: 1.0 # layer channel multiple
8
+ anchors: 3
9
+
10
+ # YOLOv7 backbone
11
+ backbone:
12
+ [[-1, 1, Conv, [32, 3, 1]], # 0
13
+
14
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
15
+ [-1, 1, Conv, [64, 3, 1]],
16
+
17
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
18
+ [-1, 1, Conv, [64, 1, 1]],
19
+ [-2, 1, Conv, [64, 1, 1]],
20
+ [-1, 1, Conv, [64, 3, 1]],
21
+ [-1, 1, Conv, [64, 3, 1]],
22
+ [-1, 1, Conv, [64, 3, 1]],
23
+ [-1, 1, Conv, [64, 3, 1]],
24
+ [[-1, -3, -5, -6], 1, Concat, [1]],
25
+ [-1, 1, Conv, [256, 1, 1]], # 11
26
+
27
+ [-1, 1, MP, []],
28
+ [-1, 1, Conv, [128, 1, 1]],
29
+ [-3, 1, Conv, [128, 1, 1]],
30
+ [-1, 1, Conv, [128, 3, 2]],
31
+ [[-1, -3], 1, Concat, [1]], # 16-P3/8
32
+ [-1, 1, Conv, [128, 1, 1]],
33
+ [-2, 1, Conv, [128, 1, 1]],
34
+ [-1, 1, Conv, [128, 3, 1]],
35
+ [-1, 1, Conv, [128, 3, 1]],
36
+ [-1, 1, Conv, [128, 3, 1]],
37
+ [-1, 1, Conv, [128, 3, 1]],
38
+ [[-1, -3, -5, -6], 1, Concat, [1]],
39
+ [-1, 1, Conv, [512, 1, 1]], # 24
40
+
41
+ [-1, 1, MP, []],
42
+ [-1, 1, Conv, [256, 1, 1]],
43
+ [-3, 1, Conv, [256, 1, 1]],
44
+ [-1, 1, Conv, [256, 3, 2]],
45
+ [[-1, -3], 1, Concat, [1]], # 29-P4/16
46
+ [-1, 1, Conv, [256, 1, 1]],
47
+ [-2, 1, Conv, [256, 1, 1]],
48
+ [-1, 1, Conv, [256, 3, 1]],
49
+ [-1, 1, Conv, [256, 3, 1]],
50
+ [-1, 1, Conv, [256, 3, 1]],
51
+ [-1, 1, Conv, [256, 3, 1]],
52
+ [[-1, -3, -5, -6], 1, Concat, [1]],
53
+ [-1, 1, Conv, [1024, 1, 1]], # 37
54
+
55
+ [-1, 1, MP, []],
56
+ [-1, 1, Conv, [512, 1, 1]],
57
+ [-3, 1, Conv, [512, 1, 1]],
58
+ [-1, 1, Conv, [512, 3, 2]],
59
+ [[-1, -3], 1, Concat, [1]], # 42-P5/32
60
+ [-1, 1, Conv, [256, 1, 1]],
61
+ [-2, 1, Conv, [256, 1, 1]],
62
+ [-1, 1, Conv, [256, 3, 1]],
63
+ [-1, 1, Conv, [256, 3, 1]],
64
+ [-1, 1, Conv, [256, 3, 1]],
65
+ [-1, 1, Conv, [256, 3, 1]],
66
+ [[-1, -3, -5, -6], 1, Concat, [1]],
67
+ [-1, 1, Conv, [1024, 1, 1]], # 50
68
+ ]
69
+
70
+ # yolov7 head
71
+ head:
72
+ [[-1, 1, SPPCSPC, [512]], # 51
73
+
74
+ [-1, 1, Conv, [256, 1, 1]],
75
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
76
+ [37, 1, Conv, [256, 1, 1]], # route backbone P4
77
+ [[-1, -2], 1, Concat, [1]],
78
+
79
+ [-1, 1, Conv, [256, 1, 1]],
80
+ [-2, 1, Conv, [256, 1, 1]],
81
+ [-1, 1, Conv, [128, 3, 1]],
82
+ [-1, 1, Conv, [128, 3, 1]],
83
+ [-1, 1, Conv, [128, 3, 1]],
84
+ [-1, 1, Conv, [128, 3, 1]],
85
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
86
+ [-1, 1, Conv, [256, 1, 1]], # 63
87
+
88
+ [-1, 1, Conv, [128, 1, 1]],
89
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
90
+ [24, 1, Conv, [128, 1, 1]], # route backbone P3
91
+ [[-1, -2], 1, Concat, [1]],
92
+
93
+ [-1, 1, Conv, [128, 1, 1]],
94
+ [-2, 1, Conv, [128, 1, 1]],
95
+ [-1, 1, Conv, [64, 3, 1]],
96
+ [-1, 1, Conv, [64, 3, 1]],
97
+ [-1, 1, Conv, [64, 3, 1]],
98
+ [-1, 1, Conv, [64, 3, 1]],
99
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
100
+ [-1, 1, Conv, [128, 1, 1]], # 75
101
+
102
+ [-1, 1, MP, []],
103
+ [-1, 1, Conv, [128, 1, 1]],
104
+ [-3, 1, Conv, [128, 1, 1]],
105
+ [-1, 1, Conv, [128, 3, 2]],
106
+ [[-1, -3, 63], 1, Concat, [1]],
107
+
108
+ [-1, 1, Conv, [256, 1, 1]],
109
+ [-2, 1, Conv, [256, 1, 1]],
110
+ [-1, 1, Conv, [128, 3, 1]],
111
+ [-1, 1, Conv, [128, 3, 1]],
112
+ [-1, 1, Conv, [128, 3, 1]],
113
+ [-1, 1, Conv, [128, 3, 1]],
114
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
115
+ [-1, 1, Conv, [256, 1, 1]], # 88
116
+
117
+ [-1, 1, MP, []],
118
+ [-1, 1, Conv, [256, 1, 1]],
119
+ [-3, 1, Conv, [256, 1, 1]],
120
+ [-1, 1, Conv, [256, 3, 2]],
121
+ [[-1, -3, 51], 1, Concat, [1]],
122
+
123
+ [-1, 1, Conv, [512, 1, 1]],
124
+ [-2, 1, Conv, [512, 1, 1]],
125
+ [-1, 1, Conv, [256, 3, 1]],
126
+ [-1, 1, Conv, [256, 3, 1]],
127
+ [-1, 1, Conv, [256, 3, 1]],
128
+ [-1, 1, Conv, [256, 3, 1]],
129
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
130
+ [-1, 1, Conv, [512, 1, 1]], # 101
131
+
132
+ [75, 1, Conv, [256, 3, 1]],
133
+ [88, 1, Conv, [512, 3, 1]],
134
+ [101, 1, Conv, [1024, 3, 1]],
135
+
136
+ [[102, 103, 104], 1, Panoptic, [nc, 93, 32, 256]], # Panoptic(P3, P4, P5)
137
+ ]
yolov9/models/segment/gelan-c-dseg.yaml ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [64, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, ADown, [256]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, ADown, [512]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, ADown, [512]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 8
42
+ ]
43
+
44
+ # gelan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [512, 256]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 15 (P3/8-small)
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, ADown, [256]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, ADown, [512]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 21 (P5/32-large)
77
+
78
+ [15, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 22
79
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
80
+ [-1, 1, Conv, [256, 3, 1]], # 24
81
+
82
+ # segment
83
+ [[15, 18, 21, 24], 1, DSegment, [nc, 32, 256]], # Segment(P3, P4, P5)
84
+ ]
yolov9/models/segment/gelan-c-seg.yaml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ # conv down
17
+ [-1, 1, Conv, [64, 3, 2]], # 0-P1/2
18
+
19
+ # conv down
20
+ [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
21
+
22
+ # elan-1 block
23
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 2
24
+
25
+ # avg-conv down
26
+ [-1, 1, ADown, [256]], # 3-P3/8
27
+
28
+ # elan-2 block
29
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 4
30
+
31
+ # avg-conv down
32
+ [-1, 1, ADown, [512]], # 5-P4/16
33
+
34
+ # elan-2 block
35
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 6
36
+
37
+ # avg-conv down
38
+ [-1, 1, ADown, [512]], # 7-P5/32
39
+
40
+ # elan-2 block
41
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 8
42
+ ]
43
+
44
+ # gelan head
45
+ head:
46
+ [
47
+ # elan-spp block
48
+ [-1, 1, SPPELAN, [512, 256]], # 9
49
+
50
+ # up-concat merge
51
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
52
+ [[-1, 6], 1, Concat, [1]], # cat backbone P4
53
+
54
+ # elan-2 block
55
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 12
56
+
57
+ # up-concat merge
58
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
59
+ [[-1, 4], 1, Concat, [1]], # cat backbone P3
60
+
61
+ # elan-2 block
62
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 15 (P3/8-small)
63
+
64
+ # avg-conv-down merge
65
+ [-1, 1, ADown, [256]],
66
+ [[-1, 12], 1, Concat, [1]], # cat head P4
67
+
68
+ # elan-2 block
69
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 18 (P4/16-medium)
70
+
71
+ # avg-conv-down merge
72
+ [-1, 1, ADown, [512]],
73
+ [[-1, 9], 1, Concat, [1]], # cat head P5
74
+
75
+ # elan-2 block
76
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 21 (P5/32-large)
77
+
78
+ # segment
79
+ [[15, 18, 21], 1, Segment, [nc, 32, 256]], # Segment(P3, P4, P5)
80
+ ]
yolov9/models/segment/yolov7-af-seg.yaml ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv7
2
+
3
+ # Parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ anchors: 3
8
+
9
+ # YOLOv7 backbone
10
+ backbone:
11
+ [[-1, 1, Conv, [32, 3, 1]], # 0
12
+
13
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
14
+ [-1, 1, Conv, [64, 3, 1]],
15
+
16
+ [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
17
+ [-1, 1, Conv, [64, 1, 1]],
18
+ [-2, 1, Conv, [64, 1, 1]],
19
+ [-1, 1, Conv, [64, 3, 1]],
20
+ [-1, 1, Conv, [64, 3, 1]],
21
+ [-1, 1, Conv, [64, 3, 1]],
22
+ [-1, 1, Conv, [64, 3, 1]],
23
+ [[-1, -3, -5, -6], 1, Concat, [1]],
24
+ [-1, 1, Conv, [256, 1, 1]], # 11
25
+
26
+ [-1, 1, MP, []],
27
+ [-1, 1, Conv, [128, 1, 1]],
28
+ [-3, 1, Conv, [128, 1, 1]],
29
+ [-1, 1, Conv, [128, 3, 2]],
30
+ [[-1, -3], 1, Concat, [1]], # 16-P3/8
31
+ [-1, 1, Conv, [128, 1, 1]],
32
+ [-2, 1, Conv, [128, 1, 1]],
33
+ [-1, 1, Conv, [128, 3, 1]],
34
+ [-1, 1, Conv, [128, 3, 1]],
35
+ [-1, 1, Conv, [128, 3, 1]],
36
+ [-1, 1, Conv, [128, 3, 1]],
37
+ [[-1, -3, -5, -6], 1, Concat, [1]],
38
+ [-1, 1, Conv, [512, 1, 1]], # 24
39
+
40
+ [-1, 1, MP, []],
41
+ [-1, 1, Conv, [256, 1, 1]],
42
+ [-3, 1, Conv, [256, 1, 1]],
43
+ [-1, 1, Conv, [256, 3, 2]],
44
+ [[-1, -3], 1, Concat, [1]], # 29-P4/16
45
+ [-1, 1, Conv, [256, 1, 1]],
46
+ [-2, 1, Conv, [256, 1, 1]],
47
+ [-1, 1, Conv, [256, 3, 1]],
48
+ [-1, 1, Conv, [256, 3, 1]],
49
+ [-1, 1, Conv, [256, 3, 1]],
50
+ [-1, 1, Conv, [256, 3, 1]],
51
+ [[-1, -3, -5, -6], 1, Concat, [1]],
52
+ [-1, 1, Conv, [1024, 1, 1]], # 37
53
+
54
+ [-1, 1, MP, []],
55
+ [-1, 1, Conv, [512, 1, 1]],
56
+ [-3, 1, Conv, [512, 1, 1]],
57
+ [-1, 1, Conv, [512, 3, 2]],
58
+ [[-1, -3], 1, Concat, [1]], # 42-P5/32
59
+ [-1, 1, Conv, [256, 1, 1]],
60
+ [-2, 1, Conv, [256, 1, 1]],
61
+ [-1, 1, Conv, [256, 3, 1]],
62
+ [-1, 1, Conv, [256, 3, 1]],
63
+ [-1, 1, Conv, [256, 3, 1]],
64
+ [-1, 1, Conv, [256, 3, 1]],
65
+ [[-1, -3, -5, -6], 1, Concat, [1]],
66
+ [-1, 1, Conv, [1024, 1, 1]], # 50
67
+ ]
68
+
69
+ # yolov7 head
70
+ head:
71
+ [[-1, 1, SPPCSPC, [512]], # 51
72
+
73
+ [-1, 1, Conv, [256, 1, 1]],
74
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
75
+ [37, 1, Conv, [256, 1, 1]], # route backbone P4
76
+ [[-1, -2], 1, Concat, [1]],
77
+
78
+ [-1, 1, Conv, [256, 1, 1]],
79
+ [-2, 1, Conv, [256, 1, 1]],
80
+ [-1, 1, Conv, [128, 3, 1]],
81
+ [-1, 1, Conv, [128, 3, 1]],
82
+ [-1, 1, Conv, [128, 3, 1]],
83
+ [-1, 1, Conv, [128, 3, 1]],
84
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
85
+ [-1, 1, Conv, [256, 1, 1]], # 63
86
+
87
+ [-1, 1, Conv, [128, 1, 1]],
88
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
89
+ [24, 1, Conv, [128, 1, 1]], # route backbone P3
90
+ [[-1, -2], 1, Concat, [1]],
91
+
92
+ [-1, 1, Conv, [128, 1, 1]],
93
+ [-2, 1, Conv, [128, 1, 1]],
94
+ [-1, 1, Conv, [64, 3, 1]],
95
+ [-1, 1, Conv, [64, 3, 1]],
96
+ [-1, 1, Conv, [64, 3, 1]],
97
+ [-1, 1, Conv, [64, 3, 1]],
98
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
99
+ [-1, 1, Conv, [128, 1, 1]], # 75
100
+
101
+ [-1, 1, MP, []],
102
+ [-1, 1, Conv, [128, 1, 1]],
103
+ [-3, 1, Conv, [128, 1, 1]],
104
+ [-1, 1, Conv, [128, 3, 2]],
105
+ [[-1, -3, 63], 1, Concat, [1]],
106
+
107
+ [-1, 1, Conv, [256, 1, 1]],
108
+ [-2, 1, Conv, [256, 1, 1]],
109
+ [-1, 1, Conv, [128, 3, 1]],
110
+ [-1, 1, Conv, [128, 3, 1]],
111
+ [-1, 1, Conv, [128, 3, 1]],
112
+ [-1, 1, Conv, [128, 3, 1]],
113
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
114
+ [-1, 1, Conv, [256, 1, 1]], # 88
115
+
116
+ [-1, 1, MP, []],
117
+ [-1, 1, Conv, [256, 1, 1]],
118
+ [-3, 1, Conv, [256, 1, 1]],
119
+ [-1, 1, Conv, [256, 3, 2]],
120
+ [[-1, -3, 51], 1, Concat, [1]],
121
+
122
+ [-1, 1, Conv, [512, 1, 1]],
123
+ [-2, 1, Conv, [512, 1, 1]],
124
+ [-1, 1, Conv, [256, 3, 1]],
125
+ [-1, 1, Conv, [256, 3, 1]],
126
+ [-1, 1, Conv, [256, 3, 1]],
127
+ [-1, 1, Conv, [256, 3, 1]],
128
+ [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
129
+ [-1, 1, Conv, [512, 1, 1]], # 101
130
+
131
+ [75, 1, Conv, [256, 3, 1]],
132
+ [88, 1, Conv, [512, 3, 1]],
133
+ [101, 1, Conv, [1024, 3, 1]],
134
+
135
+ [[102, 103, 104], 1, Segment, [nc, 32, 256]], # Segment(P3, P4, P5)
136
+ ]
yolov9/models/segment/yolov9-c-dseg.yaml ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # YOLOv9
2
+
3
+ # parameters
4
+ nc: 80 # number of classes
5
+ depth_multiple: 1.0 # model depth multiple
6
+ width_multiple: 1.0 # layer channel multiple
7
+ #activation: nn.LeakyReLU(0.1)
8
+ #activation: nn.ReLU()
9
+
10
+ # anchors
11
+ anchors: 3
12
+
13
+ # gelan backbone
14
+ backbone:
15
+ [
16
+ [-1, 1, Silence, []],
17
+
18
+ # conv down
19
+ [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
20
+
21
+ # conv down
22
+ [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
23
+
24
+ # elan-1 block
25
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 3
26
+
27
+ # avg-conv down
28
+ [-1, 1, ADown, [256]], # 4-P3/8
29
+
30
+ # elan-2 block
31
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 5
32
+
33
+ # avg-conv down
34
+ [-1, 1, ADown, [512]], # 6-P4/16
35
+
36
+ # elan-2 block
37
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 7
38
+
39
+ # avg-conv down
40
+ [-1, 1, ADown, [512]], # 8-P5/32
41
+
42
+ # elan-2 block
43
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 9
44
+ ]
45
+
46
+ # YOLOv9 head
47
+ head:
48
+ [
49
+ # elan-spp block
50
+ [-1, 1, SPPELAN, [512, 256]], # 10
51
+
52
+ # up-concat merge
53
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
54
+ [[-1, 7], 1, Concat, [1]], # cat backbone P4
55
+
56
+ # elan-2 block
57
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 13
58
+
59
+ # up-concat merge
60
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
61
+ [[-1, 5], 1, Concat, [1]], # cat backbone P3
62
+
63
+ # elan-2 block
64
+ [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 16 (P3/8-small)
65
+
66
+ # avg-conv-down merge
67
+ [-1, 1, ADown, [256]],
68
+ [[-1, 13], 1, Concat, [1]], # cat head P4
69
+
70
+ # elan-2 block
71
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 19 (P4/16-medium)
72
+
73
+ # avg-conv-down merge
74
+ [-1, 1, ADown, [512]],
75
+ [[-1, 10], 1, Concat, [1]], # cat head P5
76
+
77
+ # elan-2 block
78
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 22 (P5/32-large)
79
+
80
+
81
+ # multi-level reversible auxiliary branch
82
+
83
+ # routing
84
+ [5, 1, CBLinear, [[256]]], # 23
85
+ [7, 1, CBLinear, [[256, 512]]], # 24
86
+ [9, 1, CBLinear, [[256, 512, 512]]], # 25
87
+
88
+ # conv down
89
+ [0, 1, Conv, [64, 3, 2]], # 26-P1/2
90
+
91
+ # conv down
92
+ [-1, 1, Conv, [128, 3, 2]], # 27-P2/4
93
+
94
+ # elan-1 block
95
+ [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 28
96
+
97
+ # avg-conv down fuse
98
+ [-1, 1, ADown, [256]], # 29-P3/8
99
+ [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30
100
+
101
+ # elan-2 block
102
+ [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 31
103
+
104
+ # avg-conv down fuse
105
+ [-1, 1, ADown, [512]], # 32-P4/16
106
+ [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33
107
+
108
+ # elan-2 block
109
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 34
110
+
111
+ # avg-conv down fuse
112
+ [-1, 1, ADown, [512]], # 35-P5/32
113
+ [[25, -1], 1, CBFuse, [[2]]], # 36
114
+
115
+ # elan-2 block
116
+ [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 37
117
+
118
+ [31, 1, RepNCSPELAN4, [512, 256, 128, 2]], # 38
119
+
120
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
121
+ [-1, 1, Conv, [256, 3, 1]], # 40
122
+
123
+ [16, 1, RepNCSPELAN4, [256, 256, 128, 2]], # 41
124
+
125
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
126
+ [-1, 1, Conv, [256, 3, 1]], # 43
127
+
128
+ # segment
129
+ [[31, 34, 37, 16, 19, 22, 40, 43], 1, DualDSegment, [nc, 32, 256]], # Segment(P3, P4, P5)
130
+ ]
yolov9/models/tf.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import sys
3
+ from copy import deepcopy
4
+ from pathlib import Path
5
+
6
+ FILE = Path(__file__).resolve()
7
+ ROOT = FILE.parents[1] # YOLO root directory
8
+ if str(ROOT) not in sys.path:
9
+ sys.path.append(str(ROOT)) # add ROOT to PATH
10
+ # ROOT = ROOT.relative_to(Path.cwd()) # relative
11
+
12
+ import numpy as np
13
+ import tensorflow as tf
14
+ import torch
15
+ import torch.nn as nn
16
+ from tensorflow import keras
17
+
18
+ from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
19
+ DWConvTranspose2d, Focus, autopad)
20
+ from models.experimental import MixConv2d, attempt_load
21
+ from models.yolo import Detect, Segment
22
+ from utils.activations import SiLU
23
+ from utils.general import LOGGER, make_divisible, print_args
24
+
25
+
26
+ class TFBN(keras.layers.Layer):
27
+ # TensorFlow BatchNormalization wrapper
28
+ def __init__(self, w=None):
29
+ super().__init__()
30
+ self.bn = keras.layers.BatchNormalization(
31
+ beta_initializer=keras.initializers.Constant(w.bias.numpy()),
32
+ gamma_initializer=keras.initializers.Constant(w.weight.numpy()),
33
+ moving_mean_initializer=keras.initializers.Constant(w.running_mean.numpy()),
34
+ moving_variance_initializer=keras.initializers.Constant(w.running_var.numpy()),
35
+ epsilon=w.eps)
36
+
37
+ def call(self, inputs):
38
+ return self.bn(inputs)
39
+
40
+
41
+ class TFPad(keras.layers.Layer):
42
+ # Pad inputs in spatial dimensions 1 and 2
43
+ def __init__(self, pad):
44
+ super().__init__()
45
+ if isinstance(pad, int):
46
+ self.pad = tf.constant([[0, 0], [pad, pad], [pad, pad], [0, 0]])
47
+ else: # tuple/list
48
+ self.pad = tf.constant([[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]])
49
+
50
+ def call(self, inputs):
51
+ return tf.pad(inputs, self.pad, mode='constant', constant_values=0)
52
+
53
+
54
+ class TFConv(keras.layers.Layer):
55
+ # Standard convolution
56
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
57
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
58
+ super().__init__()
59
+ assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
60
+ # TensorFlow convolution padding is inconsistent with PyTorch (e.g. k=3 s=2 'SAME' padding)
61
+ # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
62
+ conv = keras.layers.Conv2D(
63
+ filters=c2,
64
+ kernel_size=k,
65
+ strides=s,
66
+ padding='SAME' if s == 1 else 'VALID',
67
+ use_bias=not hasattr(w, 'bn'),
68
+ kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
69
+ bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
70
+ self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
71
+ self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
72
+ self.act = activations(w.act) if act else tf.identity
73
+
74
+ def call(self, inputs):
75
+ return self.act(self.bn(self.conv(inputs)))
76
+
77
+
78
+ class TFDWConv(keras.layers.Layer):
79
+ # Depthwise convolution
80
+ def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
81
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
82
+ super().__init__()
83
+ assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels'
84
+ conv = keras.layers.DepthwiseConv2D(
85
+ kernel_size=k,
86
+ depth_multiplier=c2 // c1,
87
+ strides=s,
88
+ padding='SAME' if s == 1 else 'VALID',
89
+ use_bias=not hasattr(w, 'bn'),
90
+ depthwise_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
91
+ bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
92
+ self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
93
+ self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
94
+ self.act = activations(w.act) if act else tf.identity
95
+
96
+ def call(self, inputs):
97
+ return self.act(self.bn(self.conv(inputs)))
98
+
99
+
100
+ class TFDWConvTranspose2d(keras.layers.Layer):
101
+ # Depthwise ConvTranspose2d
102
+ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
103
+ # ch_in, ch_out, weights, kernel, stride, padding, groups
104
+ super().__init__()
105
+ assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels'
106
+ assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1'
107
+ weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
108
+ self.c1 = c1
109
+ self.conv = [
110
+ keras.layers.Conv2DTranspose(filters=1,
111
+ kernel_size=k,
112
+ strides=s,
113
+ padding='VALID',
114
+ output_padding=p2,
115
+ use_bias=True,
116
+ kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]),
117
+ bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)]
118
+
119
+ def call(self, inputs):
120
+ return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]
121
+
122
+
123
+ class TFFocus(keras.layers.Layer):
124
+ # Focus wh information into c-space
125
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
126
+ # ch_in, ch_out, kernel, stride, padding, groups
127
+ super().__init__()
128
+ self.conv = TFConv(c1 * 4, c2, k, s, p, g, act, w.conv)
129
+
130
+ def call(self, inputs): # x(b,w,h,c) -> y(b,w/2,h/2,4c)
131
+ # inputs = inputs / 255 # normalize 0-255 to 0-1
132
+ inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
133
+ return self.conv(tf.concat(inputs, 3))
134
+
135
+
136
+ class TFBottleneck(keras.layers.Layer):
137
+ # Standard bottleneck
138
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, w=None): # ch_in, ch_out, shortcut, groups, expansion
139
+ super().__init__()
140
+ c_ = int(c2 * e) # hidden channels
141
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
142
+ self.cv2 = TFConv(c_, c2, 3, 1, g=g, w=w.cv2)
143
+ self.add = shortcut and c1 == c2
144
+
145
+ def call(self, inputs):
146
+ return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
147
+
148
+
149
+ class TFCrossConv(keras.layers.Layer):
150
+ # Cross Convolution
151
+ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False, w=None):
152
+ super().__init__()
153
+ c_ = int(c2 * e) # hidden channels
154
+ self.cv1 = TFConv(c1, c_, (1, k), (1, s), w=w.cv1)
155
+ self.cv2 = TFConv(c_, c2, (k, 1), (s, 1), g=g, w=w.cv2)
156
+ self.add = shortcut and c1 == c2
157
+
158
+ def call(self, inputs):
159
+ return inputs + self.cv2(self.cv1(inputs)) if self.add else self.cv2(self.cv1(inputs))
160
+
161
+
162
+ class TFConv2d(keras.layers.Layer):
163
+ # Substitution for PyTorch nn.Conv2D
164
+ def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
165
+ super().__init__()
166
+ assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
167
+ self.conv = keras.layers.Conv2D(filters=c2,
168
+ kernel_size=k,
169
+ strides=s,
170
+ padding='VALID',
171
+ use_bias=bias,
172
+ kernel_initializer=keras.initializers.Constant(
173
+ w.weight.permute(2, 3, 1, 0).numpy()),
174
+ bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None)
175
+
176
+ def call(self, inputs):
177
+ return self.conv(inputs)
178
+
179
+
180
+ class TFBottleneckCSP(keras.layers.Layer):
181
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
182
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
183
+ # ch_in, ch_out, number, shortcut, groups, expansion
184
+ super().__init__()
185
+ c_ = int(c2 * e) # hidden channels
186
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
187
+ self.cv2 = TFConv2d(c1, c_, 1, 1, bias=False, w=w.cv2)
188
+ self.cv3 = TFConv2d(c_, c_, 1, 1, bias=False, w=w.cv3)
189
+ self.cv4 = TFConv(2 * c_, c2, 1, 1, w=w.cv4)
190
+ self.bn = TFBN(w.bn)
191
+ self.act = lambda x: keras.activations.swish(x)
192
+ self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
193
+
194
+ def call(self, inputs):
195
+ y1 = self.cv3(self.m(self.cv1(inputs)))
196
+ y2 = self.cv2(inputs)
197
+ return self.cv4(self.act(self.bn(tf.concat((y1, y2), axis=3))))
198
+
199
+
200
+ class TFC3(keras.layers.Layer):
201
+ # CSP Bottleneck with 3 convolutions
202
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
203
+ # ch_in, ch_out, number, shortcut, groups, expansion
204
+ super().__init__()
205
+ c_ = int(c2 * e) # hidden channels
206
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
207
+ self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
208
+ self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
209
+ self.m = keras.Sequential([TFBottleneck(c_, c_, shortcut, g, e=1.0, w=w.m[j]) for j in range(n)])
210
+
211
+ def call(self, inputs):
212
+ return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
213
+
214
+
215
+ class TFC3x(keras.layers.Layer):
216
+ # 3 module with cross-convolutions
217
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, w=None):
218
+ # ch_in, ch_out, number, shortcut, groups, expansion
219
+ super().__init__()
220
+ c_ = int(c2 * e) # hidden channels
221
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
222
+ self.cv2 = TFConv(c1, c_, 1, 1, w=w.cv2)
223
+ self.cv3 = TFConv(2 * c_, c2, 1, 1, w=w.cv3)
224
+ self.m = keras.Sequential([
225
+ TFCrossConv(c_, c_, k=3, s=1, g=g, e=1.0, shortcut=shortcut, w=w.m[j]) for j in range(n)])
226
+
227
+ def call(self, inputs):
228
+ return self.cv3(tf.concat((self.m(self.cv1(inputs)), self.cv2(inputs)), axis=3))
229
+
230
+
231
+ class TFSPP(keras.layers.Layer):
232
+ # Spatial pyramid pooling layer used in YOLOv3-SPP
233
+ def __init__(self, c1, c2, k=(5, 9, 13), w=None):
234
+ super().__init__()
235
+ c_ = c1 // 2 # hidden channels
236
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
237
+ self.cv2 = TFConv(c_ * (len(k) + 1), c2, 1, 1, w=w.cv2)
238
+ self.m = [keras.layers.MaxPool2D(pool_size=x, strides=1, padding='SAME') for x in k]
239
+
240
+ def call(self, inputs):
241
+ x = self.cv1(inputs)
242
+ return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
243
+
244
+
245
+ class TFSPPF(keras.layers.Layer):
246
+ # Spatial pyramid pooling-Fast layer
247
+ def __init__(self, c1, c2, k=5, w=None):
248
+ super().__init__()
249
+ c_ = c1 // 2 # hidden channels
250
+ self.cv1 = TFConv(c1, c_, 1, 1, w=w.cv1)
251
+ self.cv2 = TFConv(c_ * 4, c2, 1, 1, w=w.cv2)
252
+ self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')
253
+
254
+ def call(self, inputs):
255
+ x = self.cv1(inputs)
256
+ y1 = self.m(x)
257
+ y2 = self.m(y1)
258
+ return self.cv2(tf.concat([x, y1, y2, self.m(y2)], 3))
259
+
260
+
261
+ class TFDetect(keras.layers.Layer):
262
+ # TF YOLO Detect layer
263
+ def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
264
+ super().__init__()
265
+ self.stride = tf.convert_to_tensor(w.stride.numpy(), dtype=tf.float32)
266
+ self.nc = nc # number of classes
267
+ self.no = nc + 5 # number of outputs per anchor
268
+ self.nl = len(anchors) # number of detection layers
269
+ self.na = len(anchors[0]) // 2 # number of anchors
270
+ self.grid = [tf.zeros(1)] * self.nl # init grid
271
+ self.anchors = tf.convert_to_tensor(w.anchors.numpy(), dtype=tf.float32)
272
+ self.anchor_grid = tf.reshape(self.anchors * tf.reshape(self.stride, [self.nl, 1, 1]), [self.nl, 1, -1, 1, 2])
273
+ self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)]
274
+ self.training = False # set to False after building model
275
+ self.imgsz = imgsz
276
+ for i in range(self.nl):
277
+ ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
278
+ self.grid[i] = self._make_grid(nx, ny)
279
+
280
+ def call(self, inputs):
281
+ z = [] # inference output
282
+ x = []
283
+ for i in range(self.nl):
284
+ x.append(self.m[i](inputs[i]))
285
+ # x(bs,20,20,255) to x(bs,3,20,20,85)
286
+ ny, nx = self.imgsz[0] // self.stride[i], self.imgsz[1] // self.stride[i]
287
+ x[i] = tf.reshape(x[i], [-1, ny * nx, self.na, self.no])
288
+
289
+ if not self.training: # inference
290
+ y = x[i]
291
+ grid = tf.transpose(self.grid[i], [0, 2, 1, 3]) - 0.5
292
+ anchor_grid = tf.transpose(self.anchor_grid[i], [0, 2, 1, 3]) * 4
293
+ xy = (tf.sigmoid(y[..., 0:2]) * 2 + grid) * self.stride[i] # xy
294
+ wh = tf.sigmoid(y[..., 2:4]) ** 2 * anchor_grid
295
+ # Normalize xywh to 0-1 to reduce calibration error
296
+ xy /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
297
+ wh /= tf.constant([[self.imgsz[1], self.imgsz[0]]], dtype=tf.float32)
298
+ y = tf.concat([xy, wh, tf.sigmoid(y[..., 4:5 + self.nc]), y[..., 5 + self.nc:]], -1)
299
+ z.append(tf.reshape(y, [-1, self.na * ny * nx, self.no]))
300
+
301
+ return tf.transpose(x, [0, 2, 1, 3]) if self.training else (tf.concat(z, 1),)
302
+
303
+ @staticmethod
304
+ def _make_grid(nx=20, ny=20):
305
+ # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
306
+ # return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
307
+ xv, yv = tf.meshgrid(tf.range(nx), tf.range(ny))
308
+ return tf.cast(tf.reshape(tf.stack([xv, yv], 2), [1, 1, ny * nx, 2]), dtype=tf.float32)
309
+
310
+
311
+ class TFSegment(TFDetect):
312
+ # YOLO Segment head for segmentation models
313
+ def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), imgsz=(640, 640), w=None):
314
+ super().__init__(nc, anchors, ch, imgsz, w)
315
+ self.nm = nm # number of masks
316
+ self.npr = npr # number of protos
317
+ self.no = 5 + nc + self.nm # number of outputs per anchor
318
+ self.m = [TFConv2d(x, self.no * self.na, 1, w=w.m[i]) for i, x in enumerate(ch)] # output conv
319
+ self.proto = TFProto(ch[0], self.npr, self.nm, w=w.proto) # protos
320
+ self.detect = TFDetect.call
321
+
322
+ def call(self, x):
323
+ p = self.proto(x[0])
324
+ # p = TFUpsample(None, scale_factor=4, mode='nearest')(self.proto(x[0])) # (optional) full-size protos
325
+ p = tf.transpose(p, [0, 3, 1, 2]) # from shape(1,160,160,32) to shape(1,32,160,160)
326
+ x = self.detect(self, x)
327
+ return (x, p) if self.training else (x[0], p)
328
+
329
+
330
+ class TFProto(keras.layers.Layer):
331
+
332
+ def __init__(self, c1, c_=256, c2=32, w=None):
333
+ super().__init__()
334
+ self.cv1 = TFConv(c1, c_, k=3, w=w.cv1)
335
+ self.upsample = TFUpsample(None, scale_factor=2, mode='nearest')
336
+ self.cv2 = TFConv(c_, c_, k=3, w=w.cv2)
337
+ self.cv3 = TFConv(c_, c2, w=w.cv3)
338
+
339
+ def call(self, inputs):
340
+ return self.cv3(self.cv2(self.upsample(self.cv1(inputs))))
341
+
342
+
343
+ class TFUpsample(keras.layers.Layer):
344
+ # TF version of torch.nn.Upsample()
345
+ def __init__(self, size, scale_factor, mode, w=None): # warning: all arguments needed including 'w'
346
+ super().__init__()
347
+ assert scale_factor % 2 == 0, "scale_factor must be multiple of 2"
348
+ self.upsample = lambda x: tf.image.resize(x, (x.shape[1] * scale_factor, x.shape[2] * scale_factor), mode)
349
+ # self.upsample = keras.layers.UpSampling2D(size=scale_factor, interpolation=mode)
350
+ # with default arguments: align_corners=False, half_pixel_centers=False
351
+ # self.upsample = lambda x: tf.raw_ops.ResizeNearestNeighbor(images=x,
352
+ # size=(x.shape[1] * 2, x.shape[2] * 2))
353
+
354
+ def call(self, inputs):
355
+ return self.upsample(inputs)
356
+
357
+
358
+ class TFConcat(keras.layers.Layer):
359
+ # TF version of torch.concat()
360
+ def __init__(self, dimension=1, w=None):
361
+ super().__init__()
362
+ assert dimension == 1, "convert only NCHW to NHWC concat"
363
+ self.d = 3
364
+
365
+ def call(self, inputs):
366
+ return tf.concat(inputs, self.d)
367
+
368
+
369
+ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
370
+ LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
371
+ anchors, nc, gd, gw = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple']
372
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
373
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
374
+
375
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
376
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
377
+ m_str = m
378
+ m = eval(m) if isinstance(m, str) else m # eval strings
379
+ for j, a in enumerate(args):
380
+ try:
381
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
382
+ except NameError:
383
+ pass
384
+
385
+ n = max(round(n * gd), 1) if n > 1 else n # depth gain
386
+ if m in [
387
+ nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
388
+ BottleneckCSP, C3, C3x]:
389
+ c1, c2 = ch[f], args[0]
390
+ c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
391
+
392
+ args = [c1, c2, *args[1:]]
393
+ if m in [BottleneckCSP, C3, C3x]:
394
+ args.insert(2, n)
395
+ n = 1
396
+ elif m is nn.BatchNorm2d:
397
+ args = [ch[f]]
398
+ elif m is Concat:
399
+ c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
400
+ elif m in [Detect, Segment]:
401
+ args.append([ch[x + 1] for x in f])
402
+ if isinstance(args[1], int): # number of anchors
403
+ args[1] = [list(range(args[1] * 2))] * len(f)
404
+ if m is Segment:
405
+ args[3] = make_divisible(args[3] * gw, 8)
406
+ args.append(imgsz)
407
+ else:
408
+ c2 = ch[f]
409
+
410
+ tf_m = eval('TF' + m_str.replace('nn.', ''))
411
+ m_ = keras.Sequential([tf_m(*args, w=model.model[i][j]) for j in range(n)]) if n > 1 \
412
+ else tf_m(*args, w=model.model[i]) # module
413
+
414
+ torch_m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
415
+ t = str(m)[8:-2].replace('__main__.', '') # module type
416
+ np = sum(x.numel() for x in torch_m_.parameters()) # number params
417
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
418
+ LOGGER.info(f'{i:>3}{str(f):>18}{str(n):>3}{np:>10} {t:<40}{str(args):<30}') # print
419
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
420
+ layers.append(m_)
421
+ ch.append(c2)
422
+ return keras.Sequential(layers), sorted(save)
423
+
424
+
425
+ class TFModel:
426
+ # TF YOLO model
427
+ def __init__(self, cfg='yolo.yaml', ch=3, nc=None, model=None, imgsz=(640, 640)): # model, channels, classes
428
+ super().__init__()
429
+ if isinstance(cfg, dict):
430
+ self.yaml = cfg # model dict
431
+ else: # is *.yaml
432
+ import yaml # for torch hub
433
+ self.yaml_file = Path(cfg).name
434
+ with open(cfg) as f:
435
+ self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict
436
+
437
+ # Define model
438
+ if nc and nc != self.yaml['nc']:
439
+ LOGGER.info(f"Overriding {cfg} nc={self.yaml['nc']} with nc={nc}")
440
+ self.yaml['nc'] = nc # override yaml value
441
+ self.model, self.savelist = parse_model(deepcopy(self.yaml), ch=[ch], model=model, imgsz=imgsz)
442
+
443
+ def predict(self,
444
+ inputs,
445
+ tf_nms=False,
446
+ agnostic_nms=False,
447
+ topk_per_class=100,
448
+ topk_all=100,
449
+ iou_thres=0.45,
450
+ conf_thres=0.25):
451
+ y = [] # outputs
452
+ x = inputs
453
+ for m in self.model.layers:
454
+ if m.f != -1: # if not from previous layer
455
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
456
+
457
+ x = m(x) # run
458
+ y.append(x if m.i in self.savelist else None) # save output
459
+
460
+ # Add TensorFlow NMS
461
+ if tf_nms:
462
+ boxes = self._xywh2xyxy(x[0][..., :4])
463
+ probs = x[0][:, :, 4:5]
464
+ classes = x[0][:, :, 5:]
465
+ scores = probs * classes
466
+ if agnostic_nms:
467
+ nms = AgnosticNMS()((boxes, classes, scores), topk_all, iou_thres, conf_thres)
468
+ else:
469
+ boxes = tf.expand_dims(boxes, 2)
470
+ nms = tf.image.combined_non_max_suppression(boxes,
471
+ scores,
472
+ topk_per_class,
473
+ topk_all,
474
+ iou_thres,
475
+ conf_thres,
476
+ clip_boxes=False)
477
+ return (nms,)
478
+ return x # output [1,6300,85] = [xywh, conf, class0, class1, ...]
479
+ # x = x[0] # [x(1,6300,85), ...] to x(6300,85)
480
+ # xywh = x[..., :4] # x(6300,4) boxes
481
+ # conf = x[..., 4:5] # x(6300,1) confidences
482
+ # cls = tf.reshape(tf.cast(tf.argmax(x[..., 5:], axis=1), tf.float32), (-1, 1)) # x(6300,1) classes
483
+ # return tf.concat([conf, cls, xywh], 1)
484
+
485
+ @staticmethod
486
+ def _xywh2xyxy(xywh):
487
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
488
+ x, y, w, h = tf.split(xywh, num_or_size_splits=4, axis=-1)
489
+ return tf.concat([x - w / 2, y - h / 2, x + w / 2, y + h / 2], axis=-1)
490
+
491
+
492
+ class AgnosticNMS(keras.layers.Layer):
493
+ # TF Agnostic NMS
494
+ def call(self, input, topk_all, iou_thres, conf_thres):
495
+ # wrap map_fn to avoid TypeSpec related error https://stackoverflow.com/a/65809989/3036450
496
+ return tf.map_fn(lambda x: self._nms(x, topk_all, iou_thres, conf_thres),
497
+ input,
498
+ fn_output_signature=(tf.float32, tf.float32, tf.float32, tf.int32),
499
+ name='agnostic_nms')
500
+
501
+ @staticmethod
502
+ def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
503
+ boxes, classes, scores = x
504
+ class_inds = tf.cast(tf.argmax(classes, axis=-1), tf.float32)
505
+ scores_inp = tf.reduce_max(scores, -1)
506
+ selected_inds = tf.image.non_max_suppression(boxes,
507
+ scores_inp,
508
+ max_output_size=topk_all,
509
+ iou_threshold=iou_thres,
510
+ score_threshold=conf_thres)
511
+ selected_boxes = tf.gather(boxes, selected_inds)
512
+ padded_boxes = tf.pad(selected_boxes,
513
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]], [0, 0]],
514
+ mode="CONSTANT",
515
+ constant_values=0.0)
516
+ selected_scores = tf.gather(scores_inp, selected_inds)
517
+ padded_scores = tf.pad(selected_scores,
518
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
519
+ mode="CONSTANT",
520
+ constant_values=-1.0)
521
+ selected_classes = tf.gather(class_inds, selected_inds)
522
+ padded_classes = tf.pad(selected_classes,
523
+ paddings=[[0, topk_all - tf.shape(selected_boxes)[0]]],
524
+ mode="CONSTANT",
525
+ constant_values=-1.0)
526
+ valid_detections = tf.shape(selected_inds)[0]
527
+ return padded_boxes, padded_scores, padded_classes, valid_detections
528
+
529
+
530
+ def activations(act=nn.SiLU):
531
+ # Returns TF activation from input PyTorch activation
532
+ if isinstance(act, nn.LeakyReLU):
533
+ return lambda x: keras.activations.relu(x, alpha=0.1)
534
+ elif isinstance(act, nn.Hardswish):
535
+ return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
536
+ elif isinstance(act, (nn.SiLU, SiLU)):
537
+ return lambda x: keras.activations.swish(x)
538
+ else:
539
+ raise Exception(f'no matching TensorFlow activation found for PyTorch activation {act}')
540
+
541
+
542
+ def representative_dataset_gen(dataset, ncalib=100):
543
+ # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
544
+ for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
545
+ im = np.transpose(img, [1, 2, 0])
546
+ im = np.expand_dims(im, axis=0).astype(np.float32)
547
+ im /= 255
548
+ yield [im]
549
+ if n >= ncalib:
550
+ break
551
+
552
+
553
+ def run(
554
+ weights=ROOT / 'yolo.pt', # weights path
555
+ imgsz=(640, 640), # inference size h,w
556
+ batch_size=1, # batch size
557
+ dynamic=False, # dynamic batch size
558
+ ):
559
+ # PyTorch model
560
+ im = torch.zeros((batch_size, 3, *imgsz)) # BCHW image
561
+ model = attempt_load(weights, device=torch.device('cpu'), inplace=True, fuse=False)
562
+ _ = model(im) # inference
563
+ model.info()
564
+
565
+ # TensorFlow model
566
+ im = tf.zeros((batch_size, *imgsz, 3)) # BHWC image
567
+ tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
568
+ _ = tf_model.predict(im) # inference
569
+
570
+ # Keras model
571
+ im = keras.Input(shape=(*imgsz, 3), batch_size=None if dynamic else batch_size)
572
+ keras_model = keras.Model(inputs=im, outputs=tf_model.predict(im))
573
+ keras_model.summary()
574
+
575
+ LOGGER.info('PyTorch, TensorFlow and Keras models successfully verified.\nUse export.py for TF model export.')
576
+
577
+
578
+ def parse_opt():
579
+ parser = argparse.ArgumentParser()
580
+ parser.add_argument('--weights', type=str, default=ROOT / 'yolo.pt', help='weights path')
581
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
582
+ parser.add_argument('--batch-size', type=int, default=1, help='batch size')
583
+ parser.add_argument('--dynamic', action='store_true', help='dynamic batch size')
584
+ opt = parser.parse_args()
585
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
586
+ print_args(vars(opt))
587
+ return opt
588
+
589
+
590
+ def main(opt):
591
+ run(**vars(opt))
592
+
593
+
594
+ if __name__ == "__main__":
595
+ opt = parse_opt()
596
+ main(opt)
yolov9/models/yolo.py ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from copy import deepcopy
6
+ from pathlib import Path
7
+
8
+ FILE = Path(__file__).resolve()
9
+ ROOT = FILE.parents[1] # YOLO root directory
10
+ if str(ROOT) not in sys.path:
11
+ sys.path.append(str(ROOT)) # add ROOT to PATH
12
+ if platform.system() != 'Windows':
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import *
16
+ from models.experimental import *
17
+ from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
18
+ from utils.plots import feature_visualization
19
+ from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
20
+ time_sync)
21
+ from utils.tal.anchor_generator import make_anchors, dist2bbox
22
+
23
+ try:
24
+ import thop # for FLOPs computation
25
+ except ImportError:
26
+ thop = None
27
+
28
+
29
+ class Detect(nn.Module):
30
+ # YOLO Detect head for detection models
31
+ dynamic = False # force grid reconstruction
32
+ export = False # export mode
33
+ shape = None
34
+ anchors = torch.empty(0) # init
35
+ strides = torch.empty(0) # init
36
+
37
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
38
+ super().__init__()
39
+ self.nc = nc # number of classes
40
+ self.nl = len(ch) # number of detection layers
41
+ self.reg_max = 16
42
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
43
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
44
+ self.stride = torch.zeros(self.nl) # strides computed during build
45
+
46
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
47
+ self.cv2 = nn.ModuleList(
48
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch)
49
+ self.cv3 = nn.ModuleList(
50
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
51
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ shape = x[0].shape # BCHW
55
+ for i in range(self.nl):
56
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
57
+ if self.training:
58
+ return x
59
+ elif self.dynamic or self.shape != shape:
60
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
61
+ self.shape = shape
62
+
63
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
64
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
65
+ y = torch.cat((dbox, cls.sigmoid()), 1)
66
+ return y if self.export else (y, x)
67
+
68
+ def bias_init(self):
69
+ # Initialize Detect() biases, WARNING: requires stride availability
70
+ m = self # self.model[-1] # Detect() module
71
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
72
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
73
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
74
+ a[-1].bias.data[:] = 1.0 # box
75
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
76
+
77
+
78
+ class DDetect(nn.Module):
79
+ # YOLO Detect head for detection models
80
+ dynamic = False # force grid reconstruction
81
+ export = False # export mode
82
+ shape = None
83
+ anchors = torch.empty(0) # init
84
+ strides = torch.empty(0) # init
85
+
86
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
87
+ super().__init__()
88
+ self.nc = nc # number of classes
89
+ self.nl = len(ch) # number of detection layers
90
+ self.reg_max = 16
91
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
92
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
93
+ self.stride = torch.zeros(self.nl) # strides computed during build
94
+
95
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
96
+ self.cv2 = nn.ModuleList(
97
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch)
98
+ self.cv3 = nn.ModuleList(
99
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
100
+ self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
101
+
102
+ def forward(self, x):
103
+ shape = x[0].shape # BCHW
104
+ for i in range(self.nl):
105
+ x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
106
+ if self.training:
107
+ return x
108
+ elif self.dynamic or self.shape != shape:
109
+ self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
110
+ self.shape = shape
111
+
112
+ box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
113
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
114
+ y = torch.cat((dbox, cls.sigmoid()), 1)
115
+ return y if self.export else (y, x)
116
+
117
+ def bias_init(self):
118
+ # Initialize Detect() biases, WARNING: requires stride availability
119
+ m = self # self.model[-1] # Detect() module
120
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
121
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
122
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
123
+ a[-1].bias.data[:] = 1.0 # box
124
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
125
+
126
+
127
+ class DualDetect(nn.Module):
128
+ # YOLO Detect head for detection models
129
+ dynamic = False # force grid reconstruction
130
+ export = False # export mode
131
+ shape = None
132
+ anchors = torch.empty(0) # init
133
+ strides = torch.empty(0) # init
134
+
135
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
136
+ super().__init__()
137
+ self.nc = nc # number of classes
138
+ self.nl = len(ch) // 2 # number of detection layers
139
+ self.reg_max = 16
140
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
141
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
142
+ self.stride = torch.zeros(self.nl) # strides computed during build
143
+
144
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
145
+ c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
146
+ self.cv2 = nn.ModuleList(
147
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
148
+ self.cv3 = nn.ModuleList(
149
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
150
+ self.cv4 = nn.ModuleList(
151
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:])
152
+ self.cv5 = nn.ModuleList(
153
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
154
+ self.dfl = DFL(self.reg_max)
155
+ self.dfl2 = DFL(self.reg_max)
156
+
157
+ def forward(self, x):
158
+ shape = x[0].shape # BCHW
159
+ d1 = []
160
+ d2 = []
161
+ for i in range(self.nl):
162
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
163
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
164
+ if self.training:
165
+ return [d1, d2]
166
+ elif self.dynamic or self.shape != shape:
167
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
168
+ self.shape = shape
169
+
170
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
171
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
172
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
173
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
174
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
175
+ return y if self.export else (y, [d1, d2])
176
+
177
+ def bias_init(self):
178
+ # Initialize Detect() biases, WARNING: requires stride availability
179
+ m = self # self.model[-1] # Detect() module
180
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
181
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
182
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
183
+ a[-1].bias.data[:] = 1.0 # box
184
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
185
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
186
+ a[-1].bias.data[:] = 1.0 # box
187
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
188
+
189
+
190
+ class DualDDetect(nn.Module):
191
+ # YOLO Detect head for detection models
192
+ dynamic = False # force grid reconstruction
193
+ export = False # export mode
194
+ shape = None
195
+ anchors = torch.empty(0) # init
196
+ strides = torch.empty(0) # init
197
+
198
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
199
+ super().__init__()
200
+ self.nc = nc # number of classes
201
+ self.nl = len(ch) // 2 # number of detection layers
202
+ self.reg_max = 16
203
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
204
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
205
+ self.stride = torch.zeros(self.nl) # strides computed during build
206
+
207
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), max((ch[0], min((self.nc * 2, 128)))) # channels
208
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
209
+ self.cv2 = nn.ModuleList(
210
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4), nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
211
+ self.cv3 = nn.ModuleList(
212
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
213
+ self.cv4 = nn.ModuleList(
214
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4), nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:])
215
+ self.cv5 = nn.ModuleList(
216
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:])
217
+ self.dfl = DFL(self.reg_max)
218
+ self.dfl2 = DFL(self.reg_max)
219
+
220
+ def forward(self, x):
221
+ shape = x[0].shape # BCHW
222
+ d1 = []
223
+ d2 = []
224
+ for i in range(self.nl):
225
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
226
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
227
+ if self.training:
228
+ return [d1, d2]
229
+ elif self.dynamic or self.shape != shape:
230
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
231
+ self.shape = shape
232
+
233
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
234
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
235
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
236
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
237
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1)]
238
+ return y if self.export else (y, [d1, d2])
239
+ #y = torch.cat((dbox2, cls2.sigmoid()), 1)
240
+ #return y if self.export else (y, d2)
241
+ #y1 = torch.cat((dbox, cls.sigmoid()), 1)
242
+ #y2 = torch.cat((dbox2, cls2.sigmoid()), 1)
243
+ #return [y1, y2] if self.export else [(y1, d1), (y2, d2)]
244
+ #return [y1, y2] if self.export else [(y1, y2), (d1, d2)]
245
+
246
+ def bias_init(self):
247
+ # Initialize Detect() biases, WARNING: requires stride availability
248
+ m = self # self.model[-1] # Detect() module
249
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
250
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
251
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
252
+ a[-1].bias.data[:] = 1.0 # box
253
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
254
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
255
+ a[-1].bias.data[:] = 1.0 # box
256
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
257
+
258
+
259
+ class TripleDetect(nn.Module):
260
+ # YOLO Detect head for detection models
261
+ dynamic = False # force grid reconstruction
262
+ export = False # export mode
263
+ shape = None
264
+ anchors = torch.empty(0) # init
265
+ strides = torch.empty(0) # init
266
+
267
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
268
+ super().__init__()
269
+ self.nc = nc # number of classes
270
+ self.nl = len(ch) // 3 # number of detection layers
271
+ self.reg_max = 16
272
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
273
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
274
+ self.stride = torch.zeros(self.nl) # strides computed during build
275
+
276
+ c2, c3 = max((ch[0] // 4, self.reg_max * 4, 16)), max((ch[0], min((self.nc * 2, 128)))) # channels
277
+ c4, c5 = max((ch[self.nl] // 4, self.reg_max * 4, 16)), max((ch[self.nl], min((self.nc * 2, 128)))) # channels
278
+ c6, c7 = max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
279
+ self.cv2 = nn.ModuleList(
280
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch[:self.nl])
281
+ self.cv3 = nn.ModuleList(
282
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
283
+ self.cv4 = nn.ModuleList(
284
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, 4 * self.reg_max, 1)) for x in ch[self.nl:self.nl*2])
285
+ self.cv5 = nn.ModuleList(
286
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
287
+ self.cv6 = nn.ModuleList(
288
+ nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3), nn.Conv2d(c6, 4 * self.reg_max, 1)) for x in ch[self.nl*2:self.nl*3])
289
+ self.cv7 = nn.ModuleList(
290
+ nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
291
+ self.dfl = DFL(self.reg_max)
292
+ self.dfl2 = DFL(self.reg_max)
293
+ self.dfl3 = DFL(self.reg_max)
294
+
295
+ def forward(self, x):
296
+ shape = x[0].shape # BCHW
297
+ d1 = []
298
+ d2 = []
299
+ d3 = []
300
+ for i in range(self.nl):
301
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
302
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
303
+ d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
304
+ if self.training:
305
+ return [d1, d2, d3]
306
+ elif self.dynamic or self.shape != shape:
307
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
308
+ self.shape = shape
309
+
310
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
311
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
312
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
313
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
314
+ box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
315
+ dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
316
+ y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
317
+ return y if self.export else (y, [d1, d2, d3])
318
+
319
+ def bias_init(self):
320
+ # Initialize Detect() biases, WARNING: requires stride availability
321
+ m = self # self.model[-1] # Detect() module
322
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
323
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
324
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
325
+ a[-1].bias.data[:] = 1.0 # box
326
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
327
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
328
+ a[-1].bias.data[:] = 1.0 # box
329
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
330
+ for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
331
+ a[-1].bias.data[:] = 1.0 # box
332
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
333
+
334
+
335
+ class TripleDDetect(nn.Module):
336
+ # YOLO Detect head for detection models
337
+ dynamic = False # force grid reconstruction
338
+ export = False # export mode
339
+ shape = None
340
+ anchors = torch.empty(0) # init
341
+ strides = torch.empty(0) # init
342
+
343
+ def __init__(self, nc=80, ch=(), inplace=True): # detection layer
344
+ super().__init__()
345
+ self.nc = nc # number of classes
346
+ self.nl = len(ch) // 3 # number of detection layers
347
+ self.reg_max = 16
348
+ self.no = nc + self.reg_max * 4 # number of outputs per anchor
349
+ self.inplace = inplace # use inplace ops (e.g. slice assignment)
350
+ self.stride = torch.zeros(self.nl) # strides computed during build
351
+
352
+ c2, c3 = make_divisible(max((ch[0] // 4, self.reg_max * 4, 16)), 4), \
353
+ max((ch[0], min((self.nc * 2, 128)))) # channels
354
+ c4, c5 = make_divisible(max((ch[self.nl] // 4, self.reg_max * 4, 16)), 4), \
355
+ max((ch[self.nl], min((self.nc * 2, 128)))) # channels
356
+ c6, c7 = make_divisible(max((ch[self.nl * 2] // 4, self.reg_max * 4, 16)), 4), \
357
+ max((ch[self.nl * 2], min((self.nc * 2, 128)))) # channels
358
+ self.cv2 = nn.ModuleList(
359
+ nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3, g=4),
360
+ nn.Conv2d(c2, 4 * self.reg_max, 1, groups=4)) for x in ch[:self.nl])
361
+ self.cv3 = nn.ModuleList(
362
+ nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch[:self.nl])
363
+ self.cv4 = nn.ModuleList(
364
+ nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3, g=4),
365
+ nn.Conv2d(c4, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl:self.nl*2])
366
+ self.cv5 = nn.ModuleList(
367
+ nn.Sequential(Conv(x, c5, 3), Conv(c5, c5, 3), nn.Conv2d(c5, self.nc, 1)) for x in ch[self.nl:self.nl*2])
368
+ self.cv6 = nn.ModuleList(
369
+ nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3, g=4),
370
+ nn.Conv2d(c6, 4 * self.reg_max, 1, groups=4)) for x in ch[self.nl*2:self.nl*3])
371
+ self.cv7 = nn.ModuleList(
372
+ nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nc, 1)) for x in ch[self.nl*2:self.nl*3])
373
+ self.dfl = DFL(self.reg_max)
374
+ self.dfl2 = DFL(self.reg_max)
375
+ self.dfl3 = DFL(self.reg_max)
376
+
377
+ def forward(self, x):
378
+ shape = x[0].shape # BCHW
379
+ d1 = []
380
+ d2 = []
381
+ d3 = []
382
+ for i in range(self.nl):
383
+ d1.append(torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1))
384
+ d2.append(torch.cat((self.cv4[i](x[self.nl+i]), self.cv5[i](x[self.nl+i])), 1))
385
+ d3.append(torch.cat((self.cv6[i](x[self.nl*2+i]), self.cv7[i](x[self.nl*2+i])), 1))
386
+ if self.training:
387
+ return [d1, d2, d3]
388
+ elif self.dynamic or self.shape != shape:
389
+ self.anchors, self.strides = (d1.transpose(0, 1) for d1 in make_anchors(d1, self.stride, 0.5))
390
+ self.shape = shape
391
+
392
+ box, cls = torch.cat([di.view(shape[0], self.no, -1) for di in d1], 2).split((self.reg_max * 4, self.nc), 1)
393
+ dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
394
+ box2, cls2 = torch.cat([di.view(shape[0], self.no, -1) for di in d2], 2).split((self.reg_max * 4, self.nc), 1)
395
+ dbox2 = dist2bbox(self.dfl2(box2), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
396
+ box3, cls3 = torch.cat([di.view(shape[0], self.no, -1) for di in d3], 2).split((self.reg_max * 4, self.nc), 1)
397
+ dbox3 = dist2bbox(self.dfl3(box3), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
398
+ #y = [torch.cat((dbox, cls.sigmoid()), 1), torch.cat((dbox2, cls2.sigmoid()), 1), torch.cat((dbox3, cls3.sigmoid()), 1)]
399
+ #return y if self.export else (y, [d1, d2, d3])
400
+ y = torch.cat((dbox3, cls3.sigmoid()), 1)
401
+ return y if self.export else (y, d3)
402
+
403
+ def bias_init(self):
404
+ # Initialize Detect() biases, WARNING: requires stride availability
405
+ m = self # self.model[-1] # Detect() module
406
+ # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
407
+ # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
408
+ for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
409
+ a[-1].bias.data[:] = 1.0 # box
410
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
411
+ for a, b, s in zip(m.cv4, m.cv5, m.stride): # from
412
+ a[-1].bias.data[:] = 1.0 # box
413
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
414
+ for a, b, s in zip(m.cv6, m.cv7, m.stride): # from
415
+ a[-1].bias.data[:] = 1.0 # box
416
+ b[-1].bias.data[:m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (5 objects and 80 classes per 640 image)
417
+
418
+
419
+ class Segment(Detect):
420
+ # YOLO Segment head for segmentation models
421
+ def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
422
+ super().__init__(nc, ch, inplace)
423
+ self.nm = nm # number of masks
424
+ self.npr = npr # number of protos
425
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
426
+ self.detect = Detect.forward
427
+
428
+ c4 = max(ch[0] // 4, self.nm)
429
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
430
+
431
+ def forward(self, x):
432
+ p = self.proto(x[0])
433
+ bs = p.shape[0]
434
+
435
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
436
+ x = self.detect(self, x)
437
+ if self.training:
438
+ return x, mc, p
439
+ return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
440
+
441
+
442
+ class DSegment(DDetect):
443
+ # YOLO Segment head for segmentation models
444
+ def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
445
+ super().__init__(nc, ch[:-1], inplace)
446
+ self.nl = len(ch)-1
447
+ self.nm = nm # number of masks
448
+ self.npr = npr # number of protos
449
+ self.proto = Conv(ch[-1], self.nm, 1) # protos
450
+ self.detect = DDetect.forward
451
+
452
+ c4 = max(ch[0] // 4, self.nm)
453
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch[:-1])
454
+
455
+ def forward(self, x):
456
+ p = self.proto(x[-1])
457
+ bs = p.shape[0]
458
+
459
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
460
+ x = self.detect(self, x[:-1])
461
+ if self.training:
462
+ return x, mc, p
463
+ return (torch.cat([x, mc], 1), p) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p))
464
+
465
+
466
+ class DualDSegment(DualDDetect):
467
+ # YOLO Segment head for segmentation models
468
+ def __init__(self, nc=80, nm=32, npr=256, ch=(), inplace=True):
469
+ super().__init__(nc, ch[:-2], inplace)
470
+ self.nl = (len(ch)-2) // 2
471
+ self.nm = nm # number of masks
472
+ self.npr = npr # number of protos
473
+ self.proto = Conv(ch[-2], self.nm, 1) # protos
474
+ self.proto2 = Conv(ch[-1], self.nm, 1) # protos
475
+ self.detect = DualDDetect.forward
476
+
477
+ c6 = max(ch[0] // 4, self.nm)
478
+ c7 = max(ch[self.nl] // 4, self.nm)
479
+ self.cv6 = nn.ModuleList(nn.Sequential(Conv(x, c6, 3), Conv(c6, c6, 3), nn.Conv2d(c6, self.nm, 1)) for x in ch[:self.nl])
480
+ self.cv7 = nn.ModuleList(nn.Sequential(Conv(x, c7, 3), Conv(c7, c7, 3), nn.Conv2d(c7, self.nm, 1)) for x in ch[self.nl:self.nl*2])
481
+
482
+ def forward(self, x):
483
+ p = [self.proto(x[-2]), self.proto2(x[-1])]
484
+ bs = p[0].shape[0]
485
+
486
+ mc = [torch.cat([self.cv6[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2),
487
+ torch.cat([self.cv7[i](x[self.nl+i]).view(bs, self.nm, -1) for i in range(self.nl)], 2)] # mask coefficients
488
+ d = self.detect(self, x[:-2])
489
+ if self.training:
490
+ return d, mc, p
491
+ return (torch.cat([d[0][1], mc[1]], 1), (d[1][1], mc[1], p[1]))
492
+
493
+
494
+ class Panoptic(Detect):
495
+ # YOLO Panoptic head for panoptic segmentation models
496
+ def __init__(self, nc=80, sem_nc=93, nm=32, npr=256, ch=(), inplace=True):
497
+ super().__init__(nc, ch, inplace)
498
+ self.sem_nc = sem_nc
499
+ self.nm = nm # number of masks
500
+ self.npr = npr # number of protos
501
+ self.proto = Proto(ch[0], self.npr, self.nm) # protos
502
+ self.uconv = UConv(ch[0], ch[0]//4, self.sem_nc+self.nc)
503
+ self.detect = Detect.forward
504
+
505
+ c4 = max(ch[0] // 4, self.nm)
506
+ self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
507
+
508
+
509
+ def forward(self, x):
510
+ p = self.proto(x[0])
511
+ s = self.uconv(x[0])
512
+ bs = p.shape[0]
513
+
514
+ mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
515
+ x = self.detect(self, x)
516
+ if self.training:
517
+ return x, mc, p, s
518
+ return (torch.cat([x, mc], 1), p, s) if self.export else (torch.cat([x[0], mc], 1), (x[1], mc, p, s))
519
+
520
+
521
+ class BaseModel(nn.Module):
522
+ # YOLO base model
523
+ def forward(self, x, profile=False, visualize=False):
524
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
525
+
526
+ def _forward_once(self, x, profile=False, visualize=False):
527
+ y, dt = [], [] # outputs
528
+ for m in self.model:
529
+ if m.f != -1: # if not from previous layer
530
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
531
+ if profile:
532
+ self._profile_one_layer(m, x, dt)
533
+ x = m(x) # run
534
+ y.append(x if m.i in self.save else None) # save output
535
+ if visualize:
536
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
537
+ return x
538
+
539
+ def _profile_one_layer(self, m, x, dt):
540
+ c = m == self.model[-1] # is final layer, copy input as inplace fix
541
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
542
+ t = time_sync()
543
+ for _ in range(10):
544
+ m(x.copy() if c else x)
545
+ dt.append((time_sync() - t) * 100)
546
+ if m == self.model[0]:
547
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
548
+ LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
549
+ if c:
550
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
551
+
552
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
553
+ LOGGER.info('Fusing layers... ')
554
+ for m in self.model.modules():
555
+ if isinstance(m, (RepConvN)) and hasattr(m, 'fuse_convs'):
556
+ m.fuse_convs()
557
+ m.forward = m.forward_fuse # update forward
558
+ if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
559
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
560
+ delattr(m, 'bn') # remove batchnorm
561
+ m.forward = m.forward_fuse # update forward
562
+ self.info()
563
+ return self
564
+
565
+ def info(self, verbose=False, img_size=640): # print model information
566
+ model_info(self, verbose, img_size)
567
+
568
+ def _apply(self, fn):
569
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
570
+ self = super()._apply(fn)
571
+ m = self.model[-1] # Detect()
572
+ if isinstance(m, (Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment, DSegment, DualDSegment, Panoptic)):
573
+ m.stride = fn(m.stride)
574
+ m.anchors = fn(m.anchors)
575
+ m.strides = fn(m.strides)
576
+ # m.grid = list(map(fn, m.grid))
577
+ return self
578
+
579
+
580
+ class DetectionModel(BaseModel):
581
+ # YOLO detection model
582
+ def __init__(self, cfg='yolo.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
583
+ super().__init__()
584
+ if isinstance(cfg, dict):
585
+ self.yaml = cfg # model dict
586
+ else: # is *.yaml
587
+ import yaml # for torch hub
588
+ self.yaml_file = Path(cfg).name
589
+ with open(cfg, encoding='ascii', errors='ignore') as f:
590
+ self.yaml = yaml.safe_load(f) # model dict
591
+
592
+ # Define model
593
+ ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
594
+ if nc and nc != self.yaml['nc']:
595
+ LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
596
+ self.yaml['nc'] = nc # override yaml value
597
+ if anchors:
598
+ LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
599
+ self.yaml['anchors'] = round(anchors) # override yaml value
600
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
601
+ self.names = [str(i) for i in range(self.yaml['nc'])] # default names
602
+ self.inplace = self.yaml.get('inplace', True)
603
+
604
+ # Build strides, anchors
605
+ m = self.model[-1] # Detect()
606
+ if isinstance(m, (Detect, DDetect, Segment, DSegment, Panoptic)):
607
+ s = 256 # 2x min stride
608
+ m.inplace = self.inplace
609
+ forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, DSegment, Panoptic)) else self.forward(x)
610
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
611
+ # check_anchor_order(m)
612
+ # m.anchors /= m.stride.view(-1, 1, 1)
613
+ self.stride = m.stride
614
+ m.bias_init() # only run once
615
+ if isinstance(m, (DualDetect, TripleDetect, DualDDetect, TripleDDetect, DualDSegment)):
616
+ s = 256 # 2x min stride
617
+ m.inplace = self.inplace
618
+ forward = lambda x: self.forward(x)[0][0] if isinstance(m, (DualDSegment)) else self.forward(x)[0]
619
+ m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
620
+ # check_anchor_order(m)
621
+ # m.anchors /= m.stride.view(-1, 1, 1)
622
+ self.stride = m.stride
623
+ m.bias_init() # only run once
624
+
625
+ # Init weights, biases
626
+ initialize_weights(self)
627
+ self.info()
628
+ LOGGER.info('')
629
+
630
+ def forward(self, x, augment=False, profile=False, visualize=False):
631
+ if augment:
632
+ return self._forward_augment(x) # augmented inference, None
633
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
634
+
635
+ def _forward_augment(self, x):
636
+ img_size = x.shape[-2:] # height, width
637
+ s = [1, 0.83, 0.67] # scales
638
+ f = [None, 3, None] # flips (2-ud, 3-lr)
639
+ y = [] # outputs
640
+ for si, fi in zip(s, f):
641
+ xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
642
+ yi = self._forward_once(xi)[0] # forward
643
+ # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
644
+ yi = self._descale_pred(yi, fi, si, img_size)
645
+ y.append(yi)
646
+ y = self._clip_augmented(y) # clip augmented tails
647
+ return torch.cat(y, 1), None # augmented inference, train
648
+
649
+ def _descale_pred(self, p, flips, scale, img_size):
650
+ # de-scale predictions following augmented inference (inverse operation)
651
+ if self.inplace:
652
+ p[..., :4] /= scale # de-scale
653
+ if flips == 2:
654
+ p[..., 1] = img_size[0] - p[..., 1] # de-flip ud
655
+ elif flips == 3:
656
+ p[..., 0] = img_size[1] - p[..., 0] # de-flip lr
657
+ else:
658
+ x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale # de-scale
659
+ if flips == 2:
660
+ y = img_size[0] - y # de-flip ud
661
+ elif flips == 3:
662
+ x = img_size[1] - x # de-flip lr
663
+ p = torch.cat((x, y, wh, p[..., 4:]), -1)
664
+ return p
665
+
666
+ def _clip_augmented(self, y):
667
+ # Clip YOLO augmented inference tails
668
+ nl = self.model[-1].nl # number of detection layers (P3-P5)
669
+ g = sum(4 ** x for x in range(nl)) # grid points
670
+ e = 1 # exclude layer count
671
+ i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
672
+ y[0] = y[0][:, :-i] # large
673
+ i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
674
+ y[-1] = y[-1][:, i:] # small
675
+ return y
676
+
677
+
678
+ Model = DetectionModel # retain YOLO 'Model' class for backwards compatibility
679
+
680
+
681
+ class SegmentationModel(DetectionModel):
682
+ # YOLO segmentation model
683
+ def __init__(self, cfg='yolo-seg.yaml', ch=3, nc=None, anchors=None):
684
+ super().__init__(cfg, ch, nc, anchors)
685
+
686
+
687
+ class ClassificationModel(BaseModel):
688
+ # YOLO classification model
689
+ def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
690
+ super().__init__()
691
+ self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
692
+
693
+ def _from_detection_model(self, model, nc=1000, cutoff=10):
694
+ # Create a YOLO classification model from a YOLO detection model
695
+ if isinstance(model, DetectMultiBackend):
696
+ model = model.model # unwrap DetectMultiBackend
697
+ model.model = model.model[:cutoff] # backbone
698
+ m = model.model[-1] # last layer
699
+ ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
700
+ c = Classify(ch, nc) # Classify()
701
+ c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
702
+ model.model[-1] = c # replace
703
+ self.model = model.model
704
+ self.stride = model.stride
705
+ self.save = []
706
+ self.nc = nc
707
+
708
+ def _from_yaml(self, cfg):
709
+ # Create a YOLO classification model from a *.yaml file
710
+ self.model = None
711
+
712
+
713
+ def parse_model(d, ch): # model_dict, input_channels(3)
714
+ # Parse a YOLO model.yaml dictionary
715
+ LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
716
+ anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
717
+ if act:
718
+ Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
719
+ RepConvN.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
720
+ LOGGER.info(f"{colorstr('activation:')} {act}") # print
721
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
722
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
723
+
724
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
725
+ for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
726
+ m = eval(m) if isinstance(m, str) else m # eval strings
727
+ for j, a in enumerate(args):
728
+ with contextlib.suppress(NameError):
729
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
730
+
731
+ n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
732
+ if m in {
733
+ Conv, AConv, ConvTranspose,
734
+ Bottleneck, SPP, SPPF, DWConv, BottleneckCSP, nn.ConvTranspose2d, DWConvTranspose2d, SPPCSPC, ADown,
735
+ ELAN1, RepNCSPELAN4, SPPELAN}:
736
+ c1, c2 = ch[f], args[0]
737
+ if c2 != no: # if not output
738
+ c2 = make_divisible(c2 * gw, 8)
739
+
740
+ args = [c1, c2, *args[1:]]
741
+ if m in {BottleneckCSP, SPPCSPC}:
742
+ args.insert(2, n) # number of repeats
743
+ n = 1
744
+ elif m is nn.BatchNorm2d:
745
+ args = [ch[f]]
746
+ elif m is Concat:
747
+ c2 = sum(ch[x] for x in f)
748
+ elif m is Shortcut:
749
+ c2 = ch[f[0]]
750
+ elif m is ReOrg:
751
+ c2 = ch[f] * 4
752
+ elif m is CBLinear:
753
+ c2 = args[0]
754
+ c1 = ch[f]
755
+ args = [c1, c2, *args[1:]]
756
+ elif m is CBFuse:
757
+ c2 = ch[f[-1]]
758
+ # TODO: channel, gw, gd
759
+ elif m in {Detect, DualDetect, TripleDetect, DDetect, DualDDetect, TripleDDetect, Segment, DSegment, DualDSegment, Panoptic}:
760
+ args.append([ch[x] for x in f])
761
+ # if isinstance(args[1], int): # number of anchors
762
+ # args[1] = [list(range(args[1] * 2))] * len(f)
763
+ if m in {Segment, DSegment, DualDSegment, Panoptic}:
764
+ args[2] = make_divisible(args[2] * gw, 8)
765
+ elif m is Contract:
766
+ c2 = ch[f] * args[0] ** 2
767
+ elif m is Expand:
768
+ c2 = ch[f] // args[0] ** 2
769
+ else:
770
+ c2 = ch[f]
771
+
772
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
773
+ t = str(m)[8:-2].replace('__main__.', '') # module type
774
+ np = sum(x.numel() for x in m_.parameters()) # number params
775
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
776
+ LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # print
777
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
778
+ layers.append(m_)
779
+ if i == 0:
780
+ ch = []
781
+ ch.append(c2)
782
+ return nn.Sequential(*layers), sorted(save)
783
+
784
+
785
+ if __name__ == '__main__':
786
+ parser = argparse.ArgumentParser()
787
+ parser.add_argument('--cfg', type=str, default='yolo.yaml', help='model.yaml')
788
+ parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
789
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
790
+ parser.add_argument('--profile', action='store_true', help='profile model speed')
791
+ parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
792
+ parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
793
+ opt = parser.parse_args()
794
+ opt.cfg = check_yaml(opt.cfg) # check YAML
795
+ print_args(vars(opt))
796
+ device = select_device(opt.device)
797
+
798
+ # Create model
799
+ im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
800
+ model = Model(opt.cfg).to(device)
801
+ model.eval()
802
+
803
+ # Options
804
+ if opt.line_profile: # profile layer by layer
805
+ model(im, profile=True)
806
+
807
+ elif opt.profile: # profile forward-backward
808
+ results = profile(input=im, ops=[model], n=3)
809
+
810
+ elif opt.test: # test all models
811
+ for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
812
+ try:
813
+ _ = Model(cfg)
814
+ except Exception as e:
815
+ print(f'Error in {cfg}: {e}')
816
+
817
+ else: # report fused model summary
818
+ model.fuse()
yolov9/panoptic/predict.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import platform
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ FILE = Path(__file__).resolve()
10
+ ROOT = FILE.parents[1] # YOLO root directory
11
+ if str(ROOT) not in sys.path:
12
+ sys.path.append(str(ROOT)) # add ROOT to PATH
13
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14
+
15
+ from models.common import DetectMultiBackend
16
+ from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
17
+ from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
18
+ increment_path, non_max_suppression, print_args, scale_boxes, scale_segments,
19
+ strip_optimizer, xyxy2xywh)
20
+ from utils.plots import Annotator, colors, save_one_box
21
+ from utils.segment.general import masks2segments, process_mask
22
+ from utils.torch_utils import select_device, smart_inference_mode
23
+
24
+
25
+ @smart_inference_mode()
26
+ def run(
27
+ weights=ROOT / 'yolo-pan.pt', # model.pt path(s)
28
+ source=ROOT / 'data/images', # file/dir/URL/glob/screen/0(webcam)
29
+ data=ROOT / 'data/coco128.yaml', # dataset.yaml path
30
+ imgsz=(640, 640), # inference size (height, width)
31
+ conf_thres=0.25, # confidence threshold
32
+ iou_thres=0.45, # NMS IOU threshold
33
+ max_det=1000, # maximum detections per image
34
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
35
+ view_img=False, # show results
36
+ save_txt=False, # save results to *.txt
37
+ save_conf=False, # save confidences in --save-txt labels
38
+ save_crop=False, # save cropped prediction boxes
39
+ nosave=False, # do not save images/videos
40
+ classes=None, # filter by class: --class 0, or --class 0 2 3
41
+ agnostic_nms=False, # class-agnostic NMS
42
+ augment=False, # augmented inference
43
+ visualize=False, # visualize features
44
+ update=False, # update all models
45
+ project=ROOT / 'runs/predict-seg', # save results to project/name
46
+ name='exp', # save results to project/name
47
+ exist_ok=False, # existing project/name ok, do not increment
48
+ line_thickness=3, # bounding box thickness (pixels)
49
+ hide_labels=False, # hide labels
50
+ hide_conf=False, # hide confidences
51
+ half=False, # use FP16 half-precision inference
52
+ dnn=False, # use OpenCV DNN for ONNX inference
53
+ vid_stride=1, # video frame-rate stride
54
+ retina_masks=False,
55
+ ):
56
+ source = str(source)
57
+ save_img = not nosave and not source.endswith('.txt') # save inference images
58
+ is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
59
+ is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://'))
60
+ webcam = source.isnumeric() or source.endswith('.txt') or (is_url and not is_file)
61
+ screenshot = source.lower().startswith('screen')
62
+ if is_url and is_file:
63
+ source = check_file(source) # download
64
+
65
+ # Directories
66
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
67
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
68
+
69
+ # Load model
70
+ device = select_device(device)
71
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
72
+ stride, names, pt = model.stride, model.names, model.pt
73
+ imgsz = check_img_size(imgsz, s=stride) # check image size
74
+
75
+ # Dataloader
76
+ bs = 1 # batch_size
77
+ if webcam:
78
+ view_img = check_imshow(warn=True)
79
+ dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
80
+ bs = len(dataset)
81
+ elif screenshot:
82
+ dataset = LoadScreenshots(source, img_size=imgsz, stride=stride, auto=pt)
83
+ else:
84
+ dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
85
+ vid_path, vid_writer = [None] * bs, [None] * bs
86
+
87
+ # Run inference
88
+ model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup
89
+ seen, windows, dt = 0, [], (Profile(), Profile(), Profile())
90
+ for path, im, im0s, vid_cap, s in dataset:
91
+ with dt[0]:
92
+ im = torch.from_numpy(im).to(model.device)
93
+ im = im.half() if model.fp16 else im.float() # uint8 to fp16/32
94
+ im /= 255 # 0 - 255 to 0.0 - 1.0
95
+ if len(im.shape) == 3:
96
+ im = im[None] # expand for batch dim
97
+
98
+ # Inference
99
+ with dt[1]:
100
+ visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
101
+ pred, proto = model(im, augment=augment, visualize=visualize)[:2]
102
+
103
+ # NMS
104
+ with dt[2]:
105
+ pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det, nm=32)
106
+
107
+ # Second-stage classifier (optional)
108
+ # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
109
+
110
+ # Process predictions
111
+ for i, det in enumerate(pred): # per image
112
+ seen += 1
113
+ if webcam: # batch_size >= 1
114
+ p, im0, frame = path[i], im0s[i].copy(), dataset.count
115
+ s += f'{i}: '
116
+ else:
117
+ p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
118
+
119
+ p = Path(p) # to Path
120
+ save_path = str(save_dir / p.name) # im.jpg
121
+ txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # im.txt
122
+ s += '%gx%g ' % im.shape[2:] # print string
123
+ imc = im0.copy() if save_crop else im0 # for save_crop
124
+ annotator = Annotator(im0, line_width=line_thickness, example=str(names))
125
+ if len(det):
126
+ masks = process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True) # HWC
127
+ det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size
128
+
129
+ # Segments
130
+ if save_txt:
131
+ segments = reversed(masks2segments(masks))
132
+ segments = [scale_segments(im.shape[2:], x, im0.shape, normalize=True) for x in segments]
133
+
134
+ # Print results
135
+ for c in det[:, 5].unique():
136
+ n = (det[:, 5] == c).sum() # detections per class
137
+ s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string
138
+
139
+ # Mask plotting
140
+ annotator.masks(masks,
141
+ colors=[colors(x, True) for x in det[:, 5]],
142
+ im_gpu=None if retina_masks else im[i])
143
+
144
+ # Write results
145
+ for j, (*xyxy, conf, cls) in enumerate(reversed(det[:, :6])):
146
+ if save_txt: # Write to file
147
+ segj = segments[j].reshape(-1) # (n,2) to (n*2)
148
+ line = (cls, *segj, conf) if save_conf else (cls, *segj) # label format
149
+ with open(f'{txt_path}.txt', 'a') as f:
150
+ f.write(('%g ' * len(line)).rstrip() % line + '\n')
151
+
152
+ if save_img or save_crop or view_img: # Add bbox to image
153
+ c = int(cls) # integer class
154
+ label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f}')
155
+ annotator.box_label(xyxy, label, color=colors(c, True))
156
+ # annotator.draw.polygon(segments[j], outline=colors(c, True), width=3)
157
+ if save_crop:
158
+ save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
159
+
160
+ # Stream results
161
+ im0 = annotator.result()
162
+ if view_img:
163
+ if platform.system() == 'Linux' and p not in windows:
164
+ windows.append(p)
165
+ cv2.namedWindow(str(p), cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
166
+ cv2.resizeWindow(str(p), im0.shape[1], im0.shape[0])
167
+ cv2.imshow(str(p), im0)
168
+ if cv2.waitKey(1) == ord('q'): # 1 millisecond
169
+ exit()
170
+
171
+ # Save results (image with detections)
172
+ if save_img:
173
+ if dataset.mode == 'image':
174
+ cv2.imwrite(save_path, im0)
175
+ else: # 'video' or 'stream'
176
+ if vid_path[i] != save_path: # new video
177
+ vid_path[i] = save_path
178
+ if isinstance(vid_writer[i], cv2.VideoWriter):
179
+ vid_writer[i].release() # release previous video writer
180
+ if vid_cap: # video
181
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
182
+ w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
183
+ h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
184
+ else: # stream
185
+ fps, w, h = 30, im0.shape[1], im0.shape[0]
186
+ save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
187
+ vid_writer[i] = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
188
+ vid_writer[i].write(im0)
189
+
190
+ # Print time (inference-only)
191
+ LOGGER.info(f"{s}{'' if len(det) else '(no detections), '}{dt[1].dt * 1E3:.1f}ms")
192
+
193
+ # Print results
194
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
195
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {(1, 3, *imgsz)}' % t)
196
+ if save_txt or save_img:
197
+ s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
198
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
199
+ if update:
200
+ strip_optimizer(weights[0]) # update model (to fix SourceChangeWarning)
201
+
202
+
203
+ def parse_opt():
204
+ parser = argparse.ArgumentParser()
205
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo-pan.pt', help='model path(s)')
206
+ parser.add_argument('--source', type=str, default=ROOT / 'data/images', help='file/dir/URL/glob/screen/0(webcam)')
207
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='(optional) dataset.yaml path')
208
+ parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640], help='inference size h,w')
209
+ parser.add_argument('--conf-thres', type=float, default=0.25, help='confidence threshold')
210
+ parser.add_argument('--iou-thres', type=float, default=0.45, help='NMS IoU threshold')
211
+ parser.add_argument('--max-det', type=int, default=1000, help='maximum detections per image')
212
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
213
+ parser.add_argument('--view-img', action='store_true', help='show results')
214
+ parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
215
+ parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
216
+ parser.add_argument('--save-crop', action='store_true', help='save cropped prediction boxes')
217
+ parser.add_argument('--nosave', action='store_true', help='do not save images/videos')
218
+ parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --classes 0, or --classes 0 2 3')
219
+ parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
220
+ parser.add_argument('--augment', action='store_true', help='augmented inference')
221
+ parser.add_argument('--visualize', action='store_true', help='visualize features')
222
+ parser.add_argument('--update', action='store_true', help='update all models')
223
+ parser.add_argument('--project', default=ROOT / 'runs/predict-seg', help='save results to project/name')
224
+ parser.add_argument('--name', default='exp', help='save results to project/name')
225
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
226
+ parser.add_argument('--line-thickness', default=3, type=int, help='bounding box thickness (pixels)')
227
+ parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels')
228
+ parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences')
229
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
230
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
231
+ parser.add_argument('--vid-stride', type=int, default=1, help='video frame-rate stride')
232
+ parser.add_argument('--retina-masks', action='store_true', help='whether to plot masks in native resolution')
233
+ opt = parser.parse_args()
234
+ opt.imgsz *= 2 if len(opt.imgsz) == 1 else 1 # expand
235
+ print_args(vars(opt))
236
+ return opt
237
+
238
+
239
+ def main(opt):
240
+ check_requirements(exclude=('tensorboard', 'thop'))
241
+ run(**vars(opt))
242
+
243
+
244
+ if __name__ == "__main__":
245
+ opt = parse_opt()
246
+ main(opt)
yolov9/panoptic/train.py ADDED
@@ -0,0 +1,662 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+ import random
5
+ import sys
6
+ import time
7
+ from copy import deepcopy
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.distributed as dist
14
+ import torch.nn as nn
15
+ import yaml
16
+ from torch.optim import lr_scheduler
17
+ from tqdm import tqdm
18
+
19
+ FILE = Path(__file__).resolve()
20
+ ROOT = FILE.parents[1] # YOLO root directory
21
+ if str(ROOT) not in sys.path:
22
+ sys.path.append(str(ROOT)) # add ROOT to PATH
23
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
24
+
25
+ import panoptic.val as validate # for end-of-epoch mAP
26
+ from models.experimental import attempt_load
27
+ from models.yolo import SegmentationModel
28
+ from utils.autoanchor import check_anchors
29
+ from utils.autobatch import check_train_batch_size
30
+ from utils.callbacks import Callbacks
31
+ from utils.downloads import attempt_download, is_url
32
+ from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info,
33
+ check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr,
34
+ get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
35
+ labels_to_image_weights, one_cycle, one_flat_cycle, print_args, print_mutation, strip_optimizer, yaml_save)
36
+ from utils.loggers import GenericLogger
37
+ from utils.plots import plot_evolve, plot_labels
38
+ from utils.panoptic.dataloaders import create_dataloader
39
+ from utils.panoptic.loss_tal import ComputeLoss
40
+ from utils.panoptic.metrics import KEYS, fitness
41
+ from utils.panoptic.plots import plot_images_and_masks, plot_results_with_masks
42
+ from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
43
+ smart_resume, torch_distributed_zero_first)
44
+
45
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
46
+ RANK = int(os.getenv('RANK', -1))
47
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
48
+ GIT_INFO = None#check_git_info()
49
+
50
+
51
+ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictionary
52
+ save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, mask_ratio = \
53
+ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
54
+ opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze, opt.mask_ratio
55
+ # callbacks.run('on_pretrain_routine_start')
56
+
57
+ # Directories
58
+ w = save_dir / 'weights' # weights dir
59
+ (w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
60
+ last, best = w / 'last.pt', w / 'best.pt'
61
+
62
+ # Hyperparameters
63
+ if isinstance(hyp, str):
64
+ with open(hyp, errors='ignore') as f:
65
+ hyp = yaml.safe_load(f) # load hyps dict
66
+ LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
67
+ opt.hyp = hyp.copy() # for saving hyps to checkpoints
68
+
69
+ # Save run settings
70
+ if not evolve:
71
+ yaml_save(save_dir / 'hyp.yaml', hyp)
72
+ yaml_save(save_dir / 'opt.yaml', vars(opt))
73
+
74
+ # Loggers
75
+ data_dict = None
76
+ if RANK in {-1, 0}:
77
+ logger = GenericLogger(opt=opt, console_logger=LOGGER)
78
+
79
+ # Config
80
+ plots = not evolve and not opt.noplots # create plots
81
+ overlap = not opt.no_overlap
82
+ cuda = device.type != 'cpu'
83
+ init_seeds(opt.seed + 1 + RANK, deterministic=True)
84
+ with torch_distributed_zero_first(LOCAL_RANK):
85
+ data_dict = data_dict or check_dataset(data) # check if None
86
+ train_path, val_path = data_dict['train'], data_dict['val']
87
+ nc = 1 if single_cls else int(data_dict['nc']) # number of classes
88
+ names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
89
+ #is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
90
+ is_coco = isinstance(val_path, str) and val_path.endswith('val2017.txt') # COCO dataset
91
+
92
+ # Model
93
+ check_suffix(weights, '.pt') # check weights
94
+ pretrained = weights.endswith('.pt')
95
+ if pretrained:
96
+ with torch_distributed_zero_first(LOCAL_RANK):
97
+ weights = attempt_download(weights) # download if not found locally
98
+ ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
99
+ model = SegmentationModel(cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)
100
+ exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
101
+ csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
102
+ csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
103
+ model.load_state_dict(csd, strict=False) # load
104
+ LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}') # report
105
+ else:
106
+ model = SegmentationModel(cfg, ch=3, nc=nc).to(device) # create
107
+ amp = check_amp(model) # check AMP
108
+
109
+ # Freeze
110
+ freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
111
+ for k, v in model.named_parameters():
112
+ #v.requires_grad = True # train all layers
113
+ # v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
114
+ if any(x in k for x in freeze):
115
+ LOGGER.info(f'freezing {k}')
116
+ v.requires_grad = False
117
+
118
+ # Image size
119
+ gs = max(int(model.stride.max()), 32) # grid size (max stride)
120
+ imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2) # verify imgsz is gs-multiple
121
+
122
+ # Batch size
123
+ if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
124
+ batch_size = check_train_batch_size(model, imgsz, amp)
125
+ logger.update_params({"batch_size": batch_size})
126
+ # loggers.on_params_update({"batch_size": batch_size})
127
+
128
+ # Optimizer
129
+ nbs = 64 # nominal batch size
130
+ accumulate = max(round(nbs / batch_size), 1) # accumulate loss before optimizing
131
+ hyp['weight_decay'] *= batch_size * accumulate / nbs # scale weight_decay
132
+ optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])
133
+
134
+ # Scheduler
135
+ if opt.cos_lr:
136
+ lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
137
+ elif opt.flat_cos_lr:
138
+ lf = one_flat_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf']
139
+ elif opt.fixed_lr:
140
+ lf = lambda x: 1.0
141
+ elif opt.poly_lr:
142
+ power = 0.9
143
+ lf = lambda x: ((1 - (x / epochs)) ** power) * (1.0 - hyp['lrf']) + hyp['lrf']
144
+ else:
145
+ lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf'] # linear
146
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs)
147
+
148
+ # EMA
149
+ ema = ModelEMA(model) if RANK in {-1, 0} else None
150
+
151
+ # Resume
152
+ best_fitness, start_epoch = 0.0, 0
153
+ if pretrained:
154
+ if resume:
155
+ best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
156
+ del ckpt, csd
157
+
158
+ # DP mode
159
+ if cuda and RANK == -1 and torch.cuda.device_count() > 1:
160
+ LOGGER.warning('WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.')
161
+ model = torch.nn.DataParallel(model)
162
+
163
+ # SyncBatchNorm
164
+ if opt.sync_bn and cuda and RANK != -1:
165
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
166
+ LOGGER.info('Using SyncBatchNorm()')
167
+
168
+ # Trainloader
169
+ train_loader, dataset = create_dataloader(
170
+ train_path,
171
+ imgsz,
172
+ batch_size // WORLD_SIZE,
173
+ gs,
174
+ single_cls,
175
+ hyp=hyp,
176
+ augment=True,
177
+ cache=None if opt.cache == 'val' else opt.cache,
178
+ rect=opt.rect,
179
+ rank=LOCAL_RANK,
180
+ workers=workers,
181
+ image_weights=opt.image_weights,
182
+ close_mosaic=opt.close_mosaic != 0,
183
+ quad=opt.quad,
184
+ prefix=colorstr('train: '),
185
+ shuffle=True,
186
+ mask_downsample_ratio=mask_ratio,
187
+ overlap_mask=overlap,
188
+ )
189
+ labels = np.concatenate(dataset.labels, 0)
190
+ mlc = int(labels[:, 0].max()) # max label class
191
+ assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
192
+
193
+ # Process 0
194
+ if RANK in {-1, 0}:
195
+ val_loader = create_dataloader(val_path,
196
+ imgsz,
197
+ batch_size // WORLD_SIZE * 2,
198
+ gs,
199
+ single_cls,
200
+ hyp=hyp,
201
+ cache=None if noval else opt.cache,
202
+ rect=True,
203
+ rank=-1,
204
+ workers=workers * 2,
205
+ pad=0.5,
206
+ mask_downsample_ratio=mask_ratio,
207
+ overlap_mask=overlap,
208
+ prefix=colorstr('val: '))[0]
209
+
210
+ if not resume:
211
+ #if not opt.noautoanchor:
212
+ # check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
213
+ model.half().float() # pre-reduce anchor precision
214
+
215
+ if plots:
216
+ plot_labels(labels, names, save_dir)
217
+ # callbacks.run('on_pretrain_routine_end', labels, names)
218
+
219
+ # DDP mode
220
+ if cuda and RANK != -1:
221
+ model = smart_DDP(model)
222
+
223
+ # Model attributes
224
+ nl = de_parallel(model).model[-1].nl # number of detection layers (to scale hyps)
225
+ #hyp['box'] *= 3 / nl # scale to layers
226
+ #hyp['cls'] *= nc / 80 * 3 / nl # scale to classes and layers
227
+ #hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
228
+ hyp['label_smoothing'] = opt.label_smoothing
229
+ model.nc = nc # attach number of classes to model
230
+ model.hyp = hyp # attach hyperparameters to model
231
+ model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
232
+ model.names = names
233
+
234
+ # Start training
235
+ t0 = time.time()
236
+ nb = len(train_loader) # number of batches
237
+ nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations)
238
+ # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training
239
+ last_opt_step = -1
240
+ maps = np.zeros(nc) # mAP per class
241
+ results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
242
+ scheduler.last_epoch = start_epoch - 1 # do not move
243
+ scaler = torch.cuda.amp.GradScaler(enabled=amp)
244
+ stopper, stop = EarlyStopping(patience=opt.patience), False
245
+ compute_loss = ComputeLoss(model, overlap=overlap) # init loss class
246
+ # callbacks.run('on_train_start')
247
+ LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
248
+ f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
249
+ f"Logging results to {colorstr('bold', save_dir)}\n"
250
+ f'Starting training for {epochs} epochs...')
251
+ for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------
252
+ # callbacks.run('on_train_epoch_start')
253
+ model.train()
254
+
255
+ # Update image weights (optional, single-GPU only)
256
+ if opt.image_weights:
257
+ cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
258
+ iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
259
+ dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
260
+ if epoch == (epochs - opt.close_mosaic):
261
+ LOGGER.info("Closing dataloader mosaic")
262
+ dataset.mosaic = False
263
+
264
+ # Update mosaic border (optional)
265
+ # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
266
+ # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
267
+
268
+ mloss = torch.zeros(6, device=device) # mean losses
269
+ if RANK != -1:
270
+ train_loader.sampler.set_epoch(epoch)
271
+ pbar = enumerate(train_loader)
272
+ LOGGER.info(('\n' + '%11s' * 10) %
273
+ ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'cls_loss', 'dfl_loss', 'fcl_loss', 'dic_loss', 'Instances', 'Size'))
274
+ if RANK in {-1, 0}:
275
+ pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT) # progress bar
276
+ optimizer.zero_grad()
277
+ for i, (imgs, targets, paths, _, masks, semasks) in pbar: # batch ------------------------------------------------------
278
+ # callbacks.run('on_train_batch_start')
279
+ #print(imgs.shape)
280
+ #print(semasks.shape)
281
+ #print(masks.shape)
282
+ ni = i + nb * epoch # number integrated batches (since train start)
283
+ imgs = imgs.to(device, non_blocking=True).float() / 255 # uint8 to float32, 0-255 to 0.0-1.0
284
+
285
+ # Warmup
286
+ if ni <= nw:
287
+ xi = [0, nw] # x interp
288
+ # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0]) # iou loss ratio (obj_loss = 1.0 or iou)
289
+ accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
290
+ for j, x in enumerate(optimizer.param_groups):
291
+ # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
292
+ x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
293
+ if 'momentum' in x:
294
+ x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])
295
+
296
+ # Multi-scale
297
+ if opt.multi_scale:
298
+ sz = random.randrange(imgsz * 0.5, imgsz * 1.5 + gs) // gs * gs # size
299
+ sf = sz / max(imgs.shape[2:]) # scale factor
300
+ if sf != 1:
301
+ ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
302
+ imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
303
+
304
+ # Forward
305
+ with torch.cuda.amp.autocast(amp):
306
+ pred = model(imgs) # forward
307
+ loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float(),
308
+ semasks=semasks.to(device).float())
309
+ if RANK != -1:
310
+ loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
311
+ if opt.quad:
312
+ loss *= 4.
313
+
314
+ # Backward
315
+ torch.use_deterministic_algorithms(False)
316
+ scaler.scale(loss).backward()
317
+
318
+ # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
319
+ if ni - last_opt_step >= accumulate:
320
+ scaler.unscale_(optimizer) # unscale gradients
321
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients
322
+ scaler.step(optimizer) # optimizer.step
323
+ scaler.update()
324
+ optimizer.zero_grad()
325
+ if ema:
326
+ ema.update(model)
327
+ last_opt_step = ni
328
+
329
+ # Log
330
+ if RANK in {-1, 0}:
331
+ mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
332
+ mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
333
+ pbar.set_description(('%11s' * 2 + '%11.4g' * 8) %
334
+ (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
335
+ # callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
336
+ # if callbacks.stop_training:
337
+ # return
338
+
339
+ # Mosaic plots
340
+ if plots:
341
+ if ni < 10:
342
+ plot_images_and_masks(imgs, targets, masks, semasks, paths, save_dir / f"train_batch{ni}.jpg")
343
+ if ni == 10:
344
+ files = sorted(save_dir.glob('train*.jpg'))
345
+ logger.log_images(files, "Mosaics", epoch)
346
+ # end batch ------------------------------------------------------------------------------------------------
347
+
348
+ # Scheduler
349
+ lr = [x['lr'] for x in optimizer.param_groups] # for loggers
350
+ scheduler.step()
351
+
352
+ if RANK in {-1, 0}:
353
+ # mAP
354
+ # callbacks.run('on_train_epoch_end', epoch=epoch)
355
+ ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
356
+ final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
357
+ if not noval or final_epoch: # Calculate mAP
358
+ if (opt.save_period > 0 and epoch % opt.save_period == 0) or (epoch > (epochs - 2 * opt.close_mosaic)):
359
+ results, maps, _ = validate.run(data_dict,
360
+ batch_size=batch_size // WORLD_SIZE * 2,
361
+ imgsz=imgsz,
362
+ half=amp,
363
+ model=ema.ema,
364
+ single_cls=single_cls,
365
+ dataloader=val_loader,
366
+ save_dir=save_dir,
367
+ plots=False,
368
+ callbacks=callbacks,
369
+ compute_loss=compute_loss,
370
+ mask_downsample_ratio=mask_ratio,
371
+ overlap=overlap)
372
+
373
+ # Update best mAP
374
+ fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
375
+ stop = stopper(epoch=epoch, fitness=fi) # early stop check
376
+ if fi > best_fitness:
377
+ best_fitness = fi
378
+ log_vals = list(mloss) + list(results) + lr
379
+ # callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
380
+ # Log val metrics and media
381
+ metrics_dict = dict(zip(KEYS, log_vals))
382
+ logger.log_metrics(metrics_dict, epoch)
383
+
384
+ # Save model
385
+ if (not nosave) or (final_epoch and not evolve): # if save
386
+ ckpt = {
387
+ 'epoch': epoch,
388
+ 'best_fitness': best_fitness,
389
+ 'model': deepcopy(de_parallel(model)).half(),
390
+ 'ema': deepcopy(ema.ema).half(),
391
+ 'updates': ema.updates,
392
+ 'optimizer': optimizer.state_dict(),
393
+ 'opt': vars(opt),
394
+ 'git': GIT_INFO, # {remote, branch, commit} if a git repo
395
+ 'date': datetime.now().isoformat()}
396
+
397
+ # Save last, best and delete
398
+ torch.save(ckpt, last)
399
+ if best_fitness == fi:
400
+ torch.save(ckpt, best)
401
+ if opt.save_period > 0 and epoch % opt.save_period == 0:
402
+ torch.save(ckpt, w / f'epoch{epoch}.pt')
403
+ logger.log_model(w / f'epoch{epoch}.pt')
404
+ del ckpt
405
+ # callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
406
+
407
+ # EarlyStopping
408
+ if RANK != -1: # if DDP training
409
+ broadcast_list = [stop if RANK == 0 else None]
410
+ dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
411
+ if RANK != 0:
412
+ stop = broadcast_list[0]
413
+ if stop:
414
+ break # must break all DDP ranks
415
+
416
+ # end epoch ----------------------------------------------------------------------------------------------------
417
+ # end training -----------------------------------------------------------------------------------------------------
418
+ if RANK in {-1, 0}:
419
+ LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
420
+ for f in last, best:
421
+ if f.exists():
422
+ strip_optimizer(f) # strip optimizers
423
+ if f is best:
424
+ LOGGER.info(f'\nValidating {f}...')
425
+ results, _, _ = validate.run(
426
+ data_dict,
427
+ batch_size=batch_size // WORLD_SIZE * 2,
428
+ imgsz=imgsz,
429
+ model=attempt_load(f, device).half(),
430
+ iou_thres=0.65 if is_coco else 0.60, # best pycocotools at iou 0.65
431
+ single_cls=single_cls,
432
+ dataloader=val_loader,
433
+ save_dir=save_dir,
434
+ save_json=is_coco,
435
+ verbose=True,
436
+ plots=plots,
437
+ callbacks=callbacks,
438
+ compute_loss=compute_loss,
439
+ mask_downsample_ratio=mask_ratio,
440
+ overlap=overlap) # val best model with plots
441
+ if is_coco:
442
+ # callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
443
+ metrics_dict = dict(zip(KEYS, list(mloss) + list(results) + lr))
444
+ logger.log_metrics(metrics_dict, epoch)
445
+
446
+ # callbacks.run('on_train_end', last, best, epoch, results)
447
+ # on train end callback using genericLogger
448
+ logger.log_metrics(dict(zip(KEYS[6:22], results)), epochs)
449
+ if not opt.evolve:
450
+ logger.log_model(best, epoch)
451
+ if plots:
452
+ plot_results_with_masks(file=save_dir / 'results.csv') # save results.png
453
+ files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
454
+ files = [(save_dir / f) for f in files if (save_dir / f).exists()] # filter
455
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
456
+ logger.log_images(files, "Results", epoch + 1)
457
+ logger.log_images(sorted(save_dir.glob('val*.jpg')), "Validation", epoch + 1)
458
+ torch.cuda.empty_cache()
459
+ return results
460
+
461
+
462
+ def parse_opt(known=False):
463
+ parser = argparse.ArgumentParser()
464
+ parser.add_argument('--weights', type=str, default=ROOT / 'yolo-pan.pt', help='initial weights path')
465
+ parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
466
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-seg.yaml', help='dataset.yaml path')
467
+ parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
468
+ parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
469
+ parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
470
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
471
+ parser.add_argument('--rect', action='store_true', help='rectangular training')
472
+ parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
473
+ parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
474
+ parser.add_argument('--noval', action='store_true', help='only validate final epoch')
475
+ parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
476
+ parser.add_argument('--noplots', action='store_true', help='save no plot files')
477
+ parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
478
+ parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
479
+ parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
480
+ parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
481
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
482
+ parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
483
+ parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
484
+ parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW', 'LION'], default='SGD', help='optimizer')
485
+ parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
486
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
487
+ parser.add_argument('--project', default=ROOT / 'runs/train-pan', help='save to project/name')
488
+ parser.add_argument('--name', default='exp', help='save to project/name')
489
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
490
+ parser.add_argument('--quad', action='store_true', help='quad dataloader')
491
+ parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
492
+ parser.add_argument('--flat-cos-lr', action='store_true', help='cosine LR scheduler')
493
+ parser.add_argument('--fixed-lr', action='store_true', help='fixed LR scheduler')
494
+ parser.add_argument('--poly-lr', action='store_true', help='fixed LR scheduler')
495
+ parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
496
+ parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
497
+ parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
498
+ parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
499
+ parser.add_argument('--seed', type=int, default=0, help='Global training seed')
500
+ parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
501
+ parser.add_argument('--close-mosaic', type=int, default=0, help='Experimental')
502
+
503
+ # Instance Segmentation Args
504
+ parser.add_argument('--mask-ratio', type=int, default=4, help='Downsample the truth masks to saving memory')
505
+ parser.add_argument('--no-overlap', action='store_true', help='Overlap masks train faster at slightly less mAP')
506
+
507
+ return parser.parse_known_args()[0] if known else parser.parse_args()
508
+
509
+
510
+ def main(opt, callbacks=Callbacks()):
511
+ # Checks
512
+ if RANK in {-1, 0}:
513
+ print_args(vars(opt))
514
+ #check_git_status()
515
+ #check_requirements()
516
+
517
+ # Resume
518
+ if opt.resume and not opt.evolve: # resume from specified or most recent last.pt
519
+ last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
520
+ opt_yaml = last.parent.parent / 'opt.yaml' # train options yaml
521
+ opt_data = opt.data # original dataset
522
+ if opt_yaml.is_file():
523
+ with open(opt_yaml, errors='ignore') as f:
524
+ d = yaml.safe_load(f)
525
+ else:
526
+ d = torch.load(last, map_location='cpu')['opt']
527
+ opt = argparse.Namespace(**d) # replace
528
+ opt.cfg, opt.weights, opt.resume = '', str(last), True # reinstate
529
+ if is_url(opt_data):
530
+ opt.data = check_file(opt_data) # avoid HUB resume auth timeout
531
+ else:
532
+ opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
533
+ check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
534
+ assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
535
+ if opt.evolve:
536
+ if opt.project == str(ROOT / 'runs/train'): # if default project name, rename to runs/evolve
537
+ opt.project = str(ROOT / 'runs/evolve')
538
+ opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
539
+ if opt.name == 'cfg':
540
+ opt.name = Path(opt.cfg).stem # use model.yaml as name
541
+ opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
542
+
543
+ # DDP mode
544
+ device = select_device(opt.device, batch_size=opt.batch_size)
545
+ if LOCAL_RANK != -1:
546
+ msg = 'is not compatible with YOLO Multi-GPU DDP training'
547
+ assert not opt.image_weights, f'--image-weights {msg}'
548
+ assert not opt.evolve, f'--evolve {msg}'
549
+ assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
550
+ assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
551
+ assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
552
+ torch.cuda.set_device(LOCAL_RANK)
553
+ device = torch.device('cuda', LOCAL_RANK)
554
+ dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
555
+
556
+ # Train
557
+ if not opt.evolve:
558
+ train(opt.hyp, opt, device, callbacks)
559
+
560
+ # Evolve hyperparameters (optional)
561
+ else:
562
+ # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
563
+ meta = {
564
+ 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
565
+ 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
566
+ 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
567
+ 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
568
+ 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
569
+ 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
570
+ 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
571
+ 'box': (1, 0.02, 0.2), # box loss gain
572
+ 'cls': (1, 0.2, 4.0), # cls loss gain
573
+ 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
574
+ 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
575
+ 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
576
+ 'iou_t': (0, 0.1, 0.7), # IoU training threshold
577
+ 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
578
+ 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
579
+ 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
580
+ 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
581
+ 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
582
+ 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
583
+ 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
584
+ 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
585
+ 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
586
+ 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
587
+ 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
588
+ 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
589
+ 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
590
+ 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
591
+ 'mixup': (1, 0.0, 1.0), # image mixup (probability)
592
+ 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
593
+
594
+ with open(opt.hyp, errors='ignore') as f:
595
+ hyp = yaml.safe_load(f) # load hyps dict
596
+ if 'anchors' not in hyp: # anchors commented in hyp.yaml
597
+ hyp['anchors'] = 3
598
+ if opt.noautoanchor:
599
+ del hyp['anchors'], meta['anchors']
600
+ opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
601
+ # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
602
+ evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
603
+ if opt.bucket:
604
+ os.system(f'gsutil cp gs://{opt.bucket}/evolve.csv {evolve_csv}') # download evolve.csv if exists
605
+
606
+ for _ in range(opt.evolve): # generations to evolve
607
+ if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
608
+ # Select parent(s)
609
+ parent = 'single' # parent selection method: 'single' or 'weighted'
610
+ x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
611
+ n = min(5, len(x)) # number of previous results to consider
612
+ x = x[np.argsort(-fitness(x))][:n] # top n mutations
613
+ w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
614
+ if parent == 'single' or len(x) == 1:
615
+ # x = x[random.randint(0, n - 1)] # random selection
616
+ x = x[random.choices(range(n), weights=w)[0]] # weighted selection
617
+ elif parent == 'weighted':
618
+ x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
619
+
620
+ # Mutate
621
+ mp, s = 0.8, 0.2 # mutation probability, sigma
622
+ npr = np.random
623
+ npr.seed(int(time.time()))
624
+ g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
625
+ ng = len(meta)
626
+ v = np.ones(ng)
627
+ while all(v == 1): # mutate until a change occurs (prevent duplicates)
628
+ v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
629
+ for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
630
+ hyp[k] = float(x[i + 7] * v[i]) # mutate
631
+
632
+ # Constrain to limits
633
+ for k, v in meta.items():
634
+ hyp[k] = max(hyp[k], v[1]) # lower limit
635
+ hyp[k] = min(hyp[k], v[2]) # upper limit
636
+ hyp[k] = round(hyp[k], 5) # significant digits
637
+
638
+ # Train mutation
639
+ results = train(hyp.copy(), opt, device, callbacks)
640
+ callbacks = Callbacks()
641
+ # Write mutation results
642
+ print_mutation(KEYS, results, hyp.copy(), save_dir, opt.bucket)
643
+
644
+ # Plot results
645
+ plot_evolve(evolve_csv)
646
+ LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
647
+ f"Results saved to {colorstr('bold', save_dir)}\n"
648
+ f'Usage example: $ python train.py --hyp {evolve_yaml}')
649
+
650
+
651
+ def run(**kwargs):
652
+ # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolo.pt')
653
+ opt = parse_opt(True)
654
+ for k, v in kwargs.items():
655
+ setattr(opt, k, v)
656
+ main(opt)
657
+ return opt
658
+
659
+
660
+ if __name__ == "__main__":
661
+ opt = parse_opt()
662
+ main(opt)
yolov9/panoptic/val.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ import sys
5
+ from multiprocessing.pool import ThreadPool
6
+ from pathlib import Path
7
+
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ FILE = Path(__file__).resolve()
13
+ ROOT = FILE.parents[1] # YOLO root directory
14
+ if str(ROOT) not in sys.path:
15
+ sys.path.append(str(ROOT)) # add ROOT to PATH
16
+ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
17
+
18
+ import torch.nn.functional as F
19
+ import torchvision.transforms as transforms
20
+ from pycocotools import mask as maskUtils
21
+ from models.common import DetectMultiBackend
22
+ from models.yolo import SegmentationModel
23
+ from utils.callbacks import Callbacks
24
+ from utils.coco_utils import getCocoIds, getMappingId, getMappingIndex
25
+ from utils.general import (LOGGER, NUM_THREADS, TQDM_BAR_FORMAT, Profile, check_dataset, check_img_size,
26
+ check_requirements, check_yaml, coco80_to_coco91_class, colorstr, increment_path,
27
+ non_max_suppression, print_args, scale_boxes, xywh2xyxy, xyxy2xywh)
28
+ from utils.metrics import ConfusionMatrix, box_iou
29
+ from utils.plots import output_to_target, plot_val_study
30
+ from utils.panoptic.dataloaders import create_dataloader
31
+ from utils.panoptic.general import mask_iou, process_mask, process_mask_upsample, scale_image
32
+ from utils.panoptic.metrics import Metrics, ap_per_class_box_and_mask, Semantic_Metrics
33
+ from utils.panoptic.plots import plot_images_and_masks
34
+ from utils.torch_utils import de_parallel, select_device, smart_inference_mode
35
+
36
+
37
+ def save_one_txt(predn, save_conf, shape, file):
38
+ # Save one txt result
39
+ gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
40
+ for *xyxy, conf, cls in predn.tolist():
41
+ xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
42
+ line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
43
+ with open(file, 'a') as f:
44
+ f.write(('%g ' * len(line)).rstrip() % line + '\n')
45
+
46
+
47
+ def save_one_json(predn, jdict, path, class_map, pred_masks):
48
+ # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
49
+ from pycocotools.mask import encode
50
+
51
+ def single_encode(x):
52
+ rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
53
+ rle["counts"] = rle["counts"].decode("utf-8")
54
+ return rle
55
+
56
+ image_id = int(path.stem) if path.stem.isnumeric() else path.stem
57
+ box = xyxy2xywh(predn[:, :4]) # xywh
58
+ box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
59
+ pred_masks = np.transpose(pred_masks, (2, 0, 1))
60
+ with ThreadPool(NUM_THREADS) as pool:
61
+ rles = pool.map(single_encode, pred_masks)
62
+ for i, (p, b) in enumerate(zip(predn.tolist(), box.tolist())):
63
+ jdict.append({
64
+ 'image_id': image_id,
65
+ 'category_id': class_map[int(p[5])],
66
+ 'bbox': [round(x, 3) for x in b],
67
+ 'score': round(p[4], 5),
68
+ 'segmentation': rles[i]})
69
+
70
+
71
+ def process_batch(detections, labels, iouv, pred_masks=None, gt_masks=None, overlap=False, masks=False):
72
+ """
73
+ Return correct prediction matrix
74
+ Arguments:
75
+ detections (array[N, 6]), x1, y1, x2, y2, conf, class
76
+ labels (array[M, 5]), class, x1, y1, x2, y2
77
+ Returns:
78
+ correct (array[N, 10]), for 10 IoU levels
79
+ """
80
+ if masks:
81
+ if overlap:
82
+ nl = len(labels)
83
+ index = torch.arange(nl, device=gt_masks.device).view(nl, 1, 1) + 1
84
+ gt_masks = gt_masks.repeat(nl, 1, 1) # shape(1,640,640) -> (n,640,640)
85
+ gt_masks = torch.where(gt_masks == index, 1.0, 0.0)
86
+ if gt_masks.shape[1:] != pred_masks.shape[1:]:
87
+ gt_masks = F.interpolate(gt_masks[None], pred_masks.shape[1:], mode="bilinear", align_corners=False)[0]
88
+ gt_masks = gt_masks.gt_(0.5)
89
+ iou = mask_iou(gt_masks.view(gt_masks.shape[0], -1), pred_masks.view(pred_masks.shape[0], -1))
90
+ else: # boxes
91
+ iou = box_iou(labels[:, 1:], detections[:, :4])
92
+
93
+ correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
94
+ correct_class = labels[:, 0:1] == detections[:, 5]
95
+ for i in range(len(iouv)):
96
+ x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match
97
+ if x[0].shape[0]:
98
+ matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou]
99
+ if x[0].shape[0] > 1:
100
+ matches = matches[matches[:, 2].argsort()[::-1]]
101
+ matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
102
+ # matches = matches[matches[:, 2].argsort()[::-1]]
103
+ matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
104
+ correct[matches[:, 1].astype(int), i] = True
105
+ return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
106
+
107
+
108
+ @smart_inference_mode()
109
+ def run(
110
+ data,
111
+ weights=None, # model.pt path(s)
112
+ batch_size=32, # batch size
113
+ imgsz=640, # inference size (pixels)
114
+ conf_thres=0.001, # confidence threshold
115
+ iou_thres=0.6, # NMS IoU threshold
116
+ max_det=300, # maximum detections per image
117
+ task='val', # train, val, test, speed or study
118
+ device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
119
+ workers=8, # max dataloader workers (per RANK in DDP mode)
120
+ single_cls=False, # treat as single-class dataset
121
+ augment=False, # augmented inference
122
+ verbose=False, # verbose output
123
+ save_txt=False, # save results to *.txt
124
+ save_hybrid=False, # save label+prediction hybrid results to *.txt
125
+ save_conf=False, # save confidences in --save-txt labels
126
+ save_json=False, # save a COCO-JSON results file
127
+ project=ROOT / 'runs/val-pan', # save to project/name
128
+ name='exp', # save to project/name
129
+ exist_ok=False, # existing project/name ok, do not increment
130
+ half=True, # use FP16 half-precision inference
131
+ dnn=False, # use OpenCV DNN for ONNX inference
132
+ model=None,
133
+ dataloader=None,
134
+ save_dir=Path(''),
135
+ plots=True,
136
+ overlap=False,
137
+ mask_downsample_ratio=1,
138
+ compute_loss=None,
139
+ callbacks=Callbacks(),
140
+ ):
141
+ if save_json:
142
+ check_requirements(['pycocotools'])
143
+ process = process_mask_upsample # more accurate
144
+ else:
145
+ process = process_mask # faster
146
+
147
+ # Initialize/load model and set device
148
+ training = model is not None
149
+ if training: # called by train.py
150
+ device, pt, jit, engine = next(model.parameters()).device, True, False, False # get model device, PyTorch model
151
+ half &= device.type != 'cpu' # half precision only supported on CUDA
152
+ model.half() if half else model.float()
153
+ nm = de_parallel(model).model[-1].nm # number of masks
154
+ else: # called directly
155
+ device = select_device(device, batch_size=batch_size)
156
+
157
+ # Directories
158
+ save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
159
+ (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
160
+
161
+ # Load model
162
+ model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
163
+ stride, pt, jit, engine = model.stride, model.pt, model.jit, model.engine
164
+ imgsz = check_img_size(imgsz, s=stride) # check image size
165
+ half = model.fp16 # FP16 supported on limited backends with CUDA
166
+ nm = de_parallel(model).model.model[-1].nm if isinstance(model, SegmentationModel) else 32 # number of masks
167
+ if engine:
168
+ batch_size = model.batch_size
169
+ else:
170
+ device = model.device
171
+ if not (pt or jit):
172
+ batch_size = 1 # export.py models default to batch-size 1
173
+ LOGGER.info(f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
174
+
175
+ # Data
176
+ data = check_dataset(data) # check
177
+
178
+ # Configure
179
+ model.eval()
180
+ cuda = device.type != 'cpu'
181
+ #is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'coco{os.sep}val2017.txt') # COCO dataset
182
+ is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'val2017.txt') # COCO dataset
183
+ nc = 1 if single_cls else int(data['nc']) # number of classes
184
+ stuff_names = data.get('stuff_names', []) # names of stuff classes
185
+ stuff_nc = len(stuff_names) # number of stuff classes
186
+ iouv = torch.linspace(0.5, 0.95, 10, device=device) # iou vector for mAP@0.5:0.95
187
+ niou = iouv.numel()
188
+
189
+ # Semantic Segmentation
190
+ img_id_list = []
191
+
192
+ # Dataloader
193
+ if not training:
194
+ if pt and not single_cls: # check --weights are trained on --data
195
+ ncm = model.model.nc
196
+ assert ncm == nc, f'{weights} ({ncm} classes) trained on different --data than what you passed ({nc} ' \
197
+ f'classes). Pass correct combination of --weights and --data that are trained together.'
198
+ model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz)) # warmup
199
+ pad, rect = (0.0, False) if task == 'speed' else (0.5, pt) # square inference for benchmarks
200
+ task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
201
+ dataloader = create_dataloader(data[task],
202
+ imgsz,
203
+ batch_size,
204
+ stride,
205
+ single_cls,
206
+ pad=pad,
207
+ rect=rect,
208
+ workers=workers,
209
+ prefix=colorstr(f'{task}: '),
210
+ overlap_mask=overlap,
211
+ mask_downsample_ratio=mask_downsample_ratio)[0]
212
+
213
+ seen = 0
214
+ confusion_matrix = ConfusionMatrix(nc=nc)
215
+ names = model.names if hasattr(model, 'names') else model.module.names # get class names
216
+ if isinstance(names, (list, tuple)): # old format
217
+ names = dict(enumerate(names))
218
+ class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
219
+ s = ('%22s' + '%11s' * 12) % ('Class', 'Images', 'Instances', 'Box(P', "R", "mAP50", "mAP50-95)", "Mask(P", "R",
220
+ "mAP50", "mAP50-95)", 'S(MIoU', 'FWIoU)')
221
+ dt = Profile(), Profile(), Profile()
222
+ metrics = Metrics()
223
+ semantic_metrics = Semantic_Metrics(nc = (nc + stuff_nc), device = device)
224
+ loss = torch.zeros(6, device=device)
225
+ jdict, stats = [], []
226
+ semantic_jdict = []
227
+ # callbacks.run('on_val_start')
228
+ pbar = tqdm(dataloader, desc=s, bar_format=TQDM_BAR_FORMAT) # progress bar
229
+ for batch_i, (im, targets, paths, shapes, masks, semasks) in enumerate(pbar):
230
+ # callbacks.run('on_val_batch_start')
231
+ with dt[0]:
232
+ if cuda:
233
+ im = im.to(device, non_blocking=True)
234
+ targets = targets.to(device)
235
+ masks = masks.to(device)
236
+ semasks = semasks.to(device)
237
+ masks = masks.float()
238
+ semasks = semasks.float()
239
+ im = im.half() if half else im.float() # uint8 to fp16/32
240
+ im /= 255 # 0 - 255 to 0.0 - 1.0
241
+ nb, _, height, width = im.shape # batch size, channels, height, width
242
+
243
+ # Inference
244
+ with dt[1]:
245
+ preds, train_out = model(im)# if compute_loss else (*model(im, augment=augment)[:2], None)
246
+ #train_out, preds, protos = p if len(p) == 3 else p[1]
247
+ #preds = p
248
+ #train_out = p[1][0] if len(p[1]) == 3 else p[0]
249
+ # protos = train_out[-1]
250
+ #print(preds.shape)
251
+ #print(train_out[0].shape)
252
+ #print(train_out[1].shape)
253
+ #print(train_out[2].shape)
254
+ _, pred_masks, protos, psemasks = train_out
255
+
256
+ # Loss
257
+ if compute_loss:
258
+ loss += compute_loss(train_out, targets, masks, semasks = semasks)[1] # box, obj, cls
259
+
260
+ # NMS
261
+ targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
262
+ lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
263
+ with dt[2]:
264
+ preds = non_max_suppression(preds,
265
+ conf_thres,
266
+ iou_thres,
267
+ labels=lb,
268
+ multi_label=True,
269
+ agnostic=single_cls,
270
+ max_det=max_det,
271
+ nm=nm)
272
+
273
+ # Metrics
274
+ plot_masks = [] # masks for plotting
275
+ plot_semasks = [] # masks for plotting
276
+
277
+ if training:
278
+ semantic_metrics.update(psemasks, semasks)
279
+ else:
280
+ _, _, smh, smw = semasks.shape
281
+ semantic_metrics.update(torch.nn.functional.interpolate(psemasks, size = (smh, smw), mode = 'bilinear', align_corners = False), semasks)
282
+
283
+ if plots and batch_i < 3:
284
+ plot_semasks.append(psemasks.clone().detach().cpu())
285
+
286
+ for si, (pred, proto, psemask) in enumerate(zip(preds, protos, psemasks)):
287
+ labels = targets[targets[:, 0] == si, 1:]
288
+ nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
289
+ path, shape = Path(paths[si]), shapes[si][0]
290
+ image_id = path.stem
291
+ img_id_list.append(image_id)
292
+ correct_masks = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
293
+ correct_bboxes = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
294
+ seen += 1
295
+
296
+ if npr == 0:
297
+ if nl:
298
+ stats.append((correct_masks, correct_bboxes, *torch.zeros((2, 0), device=device), labels[:, 0]))
299
+ if plots:
300
+ confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
301
+ else:
302
+ # Masks
303
+ midx = [si] if overlap else targets[:, 0] == si
304
+ gt_masks = masks[midx]
305
+ pred_masks = process(proto, pred[:, 6:], pred[:, :4], shape=im[si].shape[1:])
306
+
307
+ # Predictions
308
+ if single_cls:
309
+ pred[:, 5] = 0
310
+ predn = pred.clone()
311
+ scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred
312
+
313
+ # Evaluate
314
+ if nl:
315
+ tbox = xywh2xyxy(labels[:, 1:5]) # target boxes
316
+ scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels
317
+ labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
318
+ correct_bboxes = process_batch(predn, labelsn, iouv)
319
+ correct_masks = process_batch(predn, labelsn, iouv, pred_masks, gt_masks, overlap=overlap, masks=True)
320
+ if plots:
321
+ confusion_matrix.process_batch(predn, labelsn)
322
+ stats.append((correct_masks, correct_bboxes, pred[:, 4], pred[:, 5], labels[:, 0])) # (conf, pcls, tcls)
323
+
324
+ pred_masks = torch.as_tensor(pred_masks, dtype=torch.uint8)
325
+ if plots and batch_i < 3:
326
+ plot_masks.append(pred_masks[:15].cpu()) # filter top 15 to plot
327
+
328
+ # Save/log
329
+ if save_txt:
330
+ save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
331
+ if save_json:
332
+ pred_masks = scale_image(im[si].shape[1:],
333
+ pred_masks.permute(1, 2, 0).contiguous().cpu().numpy(), shape, shapes[si][1])
334
+ save_one_json(predn, jdict, path, class_map, pred_masks) # append to COCO-JSON dictionary
335
+ # callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
336
+
337
+ # Semantic Segmentation
338
+ h0, w0 = shape
339
+
340
+ # resize
341
+ _, mask_h, mask_w = psemask.shape
342
+ h_ratio = mask_h / h0
343
+ w_ratio = mask_w / w0
344
+
345
+ if h_ratio == w_ratio:
346
+ psemask = torch.nn.functional.interpolate(psemask[None, :], size = (h0, w0), mode = 'bilinear', align_corners = False)
347
+ else:
348
+ transform = transforms.CenterCrop((h0, w0))
349
+
350
+ if (1 != h_ratio) and (1 != w_ratio):
351
+ h_new = h0 if (h_ratio < w_ratio) else int(mask_h / w_ratio)
352
+ w_new = w0 if (h_ratio > w_ratio) else int(mask_w / h_ratio)
353
+ psemask = torch.nn.functional.interpolate(psemask[None, :], size = (h_new, w_new), mode = 'bilinear', align_corners = False)
354
+
355
+ psemask = transform(psemask)
356
+
357
+ psemask = torch.squeeze(psemask)
358
+
359
+ nc, h, w = psemask.shape
360
+
361
+ semantic_mask = torch.flatten(psemask, start_dim = 1).permute(1, 0) # class x h x w -> (h x w) x class
362
+
363
+ max_idx = semantic_mask.argmax(1)
364
+ output_masks = torch.zeros(semantic_mask.shape).scatter(1, max_idx.cpu().unsqueeze(1), 1.0) # one hot: (h x w) x class
365
+ output_masks = torch.reshape(output_masks.permute(1, 0), (nc, h, w)) # (h x w) x class -> class x h x w
366
+ psemask = output_masks.to(device = device)
367
+
368
+ # TODO: check is_coco
369
+ instances_ids = getCocoIds(name = 'instances')
370
+ stuff_mask = torch.zeros((h, w), device = device)
371
+ check_semantic_mask = False
372
+ for idx, pred_semantic_mask in enumerate(psemask):
373
+ category_id = int(getMappingId(idx))
374
+ if 183 == category_id:
375
+ # set all non-stuff pixels to other
376
+ pred_semantic_mask = (torch.logical_xor(stuff_mask, torch.ones((h, w), device = device))).int()
377
+
378
+ # ignore the classes which all zeros / unlabeled class
379
+ if (0 >= torch.max(pred_semantic_mask)) or (0 >= category_id):
380
+ continue
381
+
382
+ if category_id not in instances_ids:
383
+ # record all stuff mask
384
+ stuff_mask = torch.logical_or(stuff_mask, pred_semantic_mask)
385
+
386
+ if (category_id not in instances_ids):
387
+ rle = maskUtils.encode(np.asfortranarray(pred_semantic_mask.cpu(), dtype = np.uint8))
388
+ rle['counts'] = rle['counts'].decode('utf-8')
389
+
390
+ temp_d = {
391
+ 'image_id': int(image_id) if image_id.isnumeric() else image_id,
392
+ 'category_id': category_id,
393
+ 'segmentation': rle,
394
+ 'score': 1
395
+ }
396
+
397
+ semantic_jdict.append(temp_d)
398
+ check_semantic_mask = True
399
+
400
+ if not check_semantic_mask:
401
+ # append a other mask for evaluation if the image without any mask
402
+ other_mask = (torch.ones((h, w), device = device)).int()
403
+
404
+ rle = maskUtils.encode(np.asfortranarray(other_mask.cpu(), dtype = np.uint8))
405
+ rle['counts'] = rle['counts'].decode('utf-8')
406
+
407
+ temp_d = {
408
+ 'image_id': int(image_id) if image_id.isnumeric() else image_id,
409
+ 'category_id': 183,
410
+ 'segmentation': rle,
411
+ 'score': 1
412
+ }
413
+
414
+ semantic_jdict.append(temp_d)
415
+
416
+ # Plot images
417
+ if plots and batch_i < 3:
418
+ if len(plot_masks):
419
+ plot_masks = torch.cat(plot_masks, dim=0)
420
+ if len(plot_semasks):
421
+ plot_semasks = torch.cat(plot_semasks, dim = 0)
422
+ plot_images_and_masks(im, targets, masks, semasks, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names)
423
+ plot_images_and_masks(im, output_to_target(preds, max_det=15), plot_masks, plot_semasks, paths,
424
+ save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
425
+
426
+ # callbacks.run('on_val_batch_end')
427
+
428
+ # Compute metrics
429
+ stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
430
+ if len(stats) and stats[0].any():
431
+ results = ap_per_class_box_and_mask(*stats, plot=plots, save_dir=save_dir, names=names)
432
+ metrics.update(results)
433
+ nt = np.bincount(stats[4].astype(int), minlength=nc) # number of targets per class
434
+
435
+ # Print results
436
+ pf = '%22s' + '%11i' * 2 + '%11.3g' * 10 # print format
437
+ LOGGER.info(pf % ("all", seen, nt.sum(), *metrics.mean_results(), *semantic_metrics.results()))
438
+ if nt.sum() == 0:
439
+ LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels')
440
+
441
+ # Print results per class
442
+ if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
443
+ for i, c in enumerate(metrics.ap_class_index):
444
+ LOGGER.info(pf % (names[c], seen, nt[c], *metrics.class_result(i), *semantic_metrics.results()))
445
+
446
+ # Print speeds
447
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
448
+ if not training:
449
+ shape = (batch_size, 3, imgsz, imgsz)
450
+ LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
451
+
452
+ # Plots
453
+ if plots:
454
+ confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
455
+ # callbacks.run('on_val_end')
456
+
457
+ mp_bbox, mr_bbox, map50_bbox, map_bbox, mp_mask, mr_mask, map50_mask, map_mask = metrics.mean_results()
458
+ miou_sem, fwiou_sem = semantic_metrics.results()
459
+ semantic_metrics.reset()
460
+
461
+ # Save JSON
462
+ if save_json and len(jdict):
463
+ w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
464
+ anno_path = Path(data.get('path', '../coco'))
465
+ anno_json = str(anno_path / 'annotations/instances_val2017.json') # annotations json
466
+ pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
467
+ LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
468
+ with open(pred_json, 'w') as f:
469
+ json.dump(jdict, f)
470
+
471
+ semantic_anno_json = str(anno_path / 'annotations/stuff_val2017.json') # annotations json
472
+ semantic_pred_json = str(save_dir / f"{w}_predictions_stuff.json") # predictions json
473
+ LOGGER.info(f'\nsaving {semantic_pred_json}...')
474
+ with open(semantic_pred_json, 'w') as f:
475
+ json.dump(semantic_jdict, f)
476
+
477
+ try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
478
+ from pycocotools.coco import COCO
479
+ from pycocotools.cocoeval import COCOeval
480
+
481
+ anno = COCO(anno_json) # init annotations api
482
+ pred = anno.loadRes(pred_json) # init predictions api
483
+ results = []
484
+ for eval in COCOeval(anno, pred, 'bbox'), COCOeval(anno, pred, 'segm'):
485
+ if is_coco:
486
+ eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.im_files] # img ID to evaluate
487
+ eval.evaluate()
488
+ eval.accumulate()
489
+ eval.summarize()
490
+ results.extend(eval.stats[:2]) # update results (mAP@0.5:0.95, mAP@0.5)
491
+ map_bbox, map50_bbox, map_mask, map50_mask = results
492
+
493
+ # Semantic Segmentation
494
+ from utils.stuff_seg.cocostuffeval import COCOStuffeval
495
+
496
+ LOGGER.info(f'\nEvaluating pycocotools stuff... ')
497
+ imgIds = [int(x) for x in img_id_list]
498
+
499
+ stuffGt = COCO(semantic_anno_json) # initialize COCO ground truth api
500
+ stuffDt = stuffGt.loadRes(semantic_pred_json) # initialize COCO pred api
501
+
502
+ cocoStuffEval = COCOStuffeval(stuffGt, stuffDt)
503
+ cocoStuffEval.params.imgIds = imgIds # image IDs to evaluate
504
+ cocoStuffEval.evaluate()
505
+ stats, statsClass = cocoStuffEval.summarize()
506
+ stuffIds = getCocoIds(name = 'stuff')
507
+ title = ' {:<5} | {:^6} | {:^6} '.format('class', 'iou', 'macc') if (0 >= len(stuff_names)) else \
508
+ ' {:<5} | {:<20} | {:^6} | {:^6} '.format('class', 'class name', 'iou', 'macc')
509
+ print(title)
510
+ for idx, (iou, macc) in enumerate(zip(statsClass['ious'], statsClass['maccs'])):
511
+ id = (idx + 1)
512
+ if id not in stuffIds:
513
+ continue
514
+ content = ' {:<5} | {:0.4f} | {:0.4f} '.format(str(id), iou, macc) if (0 >= len(stuff_names)) else \
515
+ ' {:<5} | {:<20} | {:0.4f} | {:0.4f} '.format(str(id), str(stuff_names[getMappingIndex(id, name = 'stuff')]), iou, macc)
516
+ print(content)
517
+
518
+ except Exception as e:
519
+ LOGGER.info(f'pycocotools unable to run: {e}')
520
+
521
+ # Return results
522
+ model.float() # for training
523
+ if not training:
524
+ s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
525
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
526
+ final_metric = mp_bbox, mr_bbox, map50_bbox, map_bbox, mp_mask, mr_mask, map50_mask, map_mask, miou_sem, fwiou_sem
527
+ return (*final_metric, *(loss.cpu() / len(dataloader)).tolist()), metrics.get_maps(nc), t
528
+
529
+
530
+ def parse_opt():
531
+ parser = argparse.ArgumentParser()
532
+ parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-pan.yaml', help='dataset.yaml path')
533
+ parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo-pan.pt', help='model path(s)')
534
+ parser.add_argument('--batch-size', type=int, default=32, help='batch size')
535
+ parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)')
536
+ parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
537
+ parser.add_argument('--iou-thres', type=float, default=0.6, help='NMS IoU threshold')
538
+ parser.add_argument('--max-det', type=int, default=300, help='maximum detections per image')
539
+ parser.add_argument('--task', default='val', help='train, val, test, speed or study')
540
+ parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
541
+ parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
542
+ parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
543
+ parser.add_argument('--augment', action='store_true', help='augmented inference')
544
+ parser.add_argument('--verbose', action='store_true', help='report mAP by class')
545
+ parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
546
+ parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
547
+ parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
548
+ parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file')
549
+ parser.add_argument('--project', default=ROOT / 'runs/val-pan', help='save results to project/name')
550
+ parser.add_argument('--name', default='exp', help='save to project/name')
551
+ parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
552
+ parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
553
+ parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
554
+ opt = parser.parse_args()
555
+ opt.data = check_yaml(opt.data) # check YAML
556
+ # opt.save_json |= opt.data.endswith('coco.yaml')
557
+ opt.save_txt |= opt.save_hybrid
558
+ print_args(vars(opt))
559
+ return opt
560
+
561
+
562
+ def main(opt):
563
+ #check_requirements(requirements=ROOT / 'requirements.txt', exclude=('tensorboard', 'thop'))
564
+
565
+ if opt.task in ('train', 'val', 'test'): # run normally
566
+ if opt.conf_thres > 0.001: # https://github.com/
567
+ LOGGER.warning(f'WARNING ⚠️ confidence threshold {opt.conf_thres} > 0.001 produces invalid results')
568
+ if opt.save_hybrid:
569
+ LOGGER.warning('WARNING ⚠️ --save-hybrid returns high mAP from hybrid labels, not from predictions alone')
570
+ run(**vars(opt))
571
+
572
+ else:
573
+ weights = opt.weights if isinstance(opt.weights, list) else [opt.weights]
574
+ opt.half = torch.cuda.is_available() and opt.device != 'cpu' # FP16 for fastest results
575
+ if opt.task == 'speed': # speed benchmarks
576
+ # python val.py --task speed --data coco.yaml --batch 1 --weights yolo.pt...
577
+ opt.conf_thres, opt.iou_thres, opt.save_json = 0.25, 0.45, False
578
+ for opt.weights in weights:
579
+ run(**vars(opt), plots=False)
580
+
581
+ elif opt.task == 'study': # speed vs mAP benchmarks
582
+ # python val.py --task study --data coco.yaml --iou 0.7 --weights yolo.pt...
583
+ for opt.weights in weights:
584
+ f = f'study_{Path(opt.data).stem}_{Path(opt.weights).stem}.txt' # filename to save to
585
+ x, y = list(range(256, 1536 + 128, 128)), [] # x axis (image sizes), y axis
586
+ for opt.imgsz in x: # img-size
587
+ LOGGER.info(f'\nRunning {f} --imgsz {opt.imgsz}...')
588
+ r, _, t = run(**vars(opt), plots=False)
589
+ y.append(r + t) # results and times
590
+ np.savetxt(f, y, fmt='%10.4g') # save
591
+ os.system('zip -r study.zip study_*.txt')
592
+ plot_val_study(x=x) # plot
593
+
594
+
595
+ if __name__ == "__main__":
596
+ opt = parse_opt()
597
+ main(opt)