Add files using upload-large-folder tool
- .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/AUTHORS +19 -0
- .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/INSTALLER +1 -0
- .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/LICENSE +29 -0
- .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/METADATA +98 -0
- .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/RECORD +29 -0
- .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/WHEEL +6 -0
- .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/top_level.txt +1 -0
- .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/zip-safe +1 -0
- .venv/lib/python3.11/site-packages/huggingface_hub-0.28.1.dist-info/METADATA +308 -0
- .venv/lib/python3.11/site-packages/huggingface_hub-0.28.1.dist-info/top_level.txt +1 -0
- .venv/lib/python3.11/site-packages/nvidia_cuda_cupti_cu12-12.4.127.dist-info/METADATA +35 -0
- .venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/LICENSE +22 -0
- .venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/METADATA +586 -0
- .venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/WHEEL +5 -0
- .venv/lib/python3.11/site-packages/torchvision/__init__.py +105 -0
- .venv/lib/python3.11/site-packages/torchvision/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/__pycache__/_internally_replaced_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/__pycache__/_meta_registrations.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/__pycache__/_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/__pycache__/extension.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/__pycache__/version.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/_internally_replaced_utils.py +50 -0
- .venv/lib/python3.11/site-packages/torchvision/_meta_registrations.py +225 -0
- .venv/lib/python3.11/site-packages/torchvision/_utils.py +32 -0
- .venv/lib/python3.11/site-packages/torchvision/extension.py +92 -0
- .venv/lib/python3.11/site-packages/torchvision/models/alexnet.py +119 -0
- .venv/lib/python3.11/site-packages/torchvision/models/convnext.py +414 -0
- .venv/lib/python3.11/site-packages/torchvision/models/densenet.py +448 -0
- .venv/lib/python3.11/site-packages/torchvision/models/efficientnet.py +1131 -0
- .venv/lib/python3.11/site-packages/torchvision/models/googlenet.py +345 -0
- .venv/lib/python3.11/site-packages/torchvision/models/maxvit.py +833 -0
- .venv/lib/python3.11/site-packages/torchvision/models/mobilenetv3.py +423 -0
- .venv/lib/python3.11/site-packages/torchvision/models/squeezenet.py +223 -0
- .venv/lib/python3.11/site-packages/torchvision/models/vgg.py +511 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__init__.py +73 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_box_convert.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_register_onnx_ops.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/boxes.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/ciou_loss.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/deform_conv.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/diou_loss.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/drop_block.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/feature_pyramid_network.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/focal_loss.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/giou_loss.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/misc.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/poolers.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/AUTHORS
ADDED
@@ -0,0 +1,19 @@
+Original author of astor/codegen.py:
+* Armin Ronacher <armin.ronacher@active-4.com>
+
+And with some modifications based on Armin's code:
+* Paul Dubs <paul.dubs@gmail.com>
+
+* Berker Peksag <berker.peksag@gmail.com>
+* Patrick Maupin <pmaupin@gmail.com>
+* Abhishek L <abhishek.lekshmanan@gmail.com>
+* Bob Tolbert <bob@eyesopen.com>
+* Whyzgeek <whyzgeek@gmail.com>
+* Zack M. Davis <code@zackmdavis.net>
+* Ryan Gonzalez <rymg19@gmail.com>
+* Lenny Truong <leonardtruong@protonmail.com>
+* Radomír Bosák <radomir.bosak@gmail.com>
+* Kodi Arfer <git@arfer.net>
+* Felix Yan <felixonmars@archlinux.org>
+* Chris Rink <chrisrink10@gmail.com>
+* Batuhan Taskaya <batuhanosmantaskaya@gmail.com>
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/INSTALLER
ADDED
@@ -0,0 +1 @@
+pip
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/LICENSE
ADDED
@@ -0,0 +1,29 @@
+Copyright (c) 2012, Patrick Maupin
+Copyright (c) 2013, Berker Peksag
+Copyright (c) 2008, Armin Ronacher
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its contributors
+may be used to endorse or promote products derived from this software without
+specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/METADATA
ADDED
@@ -0,0 +1,98 @@
+Metadata-Version: 2.1
+Name: astor
+Version: 0.8.1
+Summary: Read/rewrite/write Python ASTs
+Home-page: https://github.com/berkerpeksag/astor
+Author: Patrick Maupin
+Author-email: pmaupin@gmail.com
+License: BSD-3-Clause
+Keywords: ast,codegen,PEP 8
+Platform: Independent
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Environment :: Console
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: BSD License
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 2
+Classifier: Programming Language :: Python :: 2.7
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.4
+Classifier: Programming Language :: Python :: 3.5
+Classifier: Programming Language :: Python :: 3.6
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: Implementation
+Classifier: Programming Language :: Python :: Implementation :: CPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: Topic :: Software Development :: Code Generators
+Classifier: Topic :: Software Development :: Compilers
+Requires-Python: !=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7
+
+=============================
+astor -- AST observe/rewrite
+=============================
+
+:PyPI: https://pypi.org/project/astor/
+:Documentation: https://astor.readthedocs.io
+:Source: https://github.com/berkerpeksag/astor
+:License: 3-clause BSD
+:Build status:
+    .. image:: https://secure.travis-ci.org/berkerpeksag/astor.svg
+       :alt: Travis CI
+       :target: https://travis-ci.org/berkerpeksag/astor/
+
+astor is designed to allow easy manipulation of Python source via the AST.
+
+There are some other similar libraries, but astor focuses on the following areas:
+
+- Round-trip an AST back to Python [1]_:
+
+  - Modified AST doesn't need linenumbers, ctx, etc. or otherwise
+    be directly compileable for the round-trip to work.
+  - Easy to read generated code as, well, code
+  - Can round-trip two different source trees to compare for functional
+    differences, using the astor.rtrip tool (for example, after PEP8 edits).
+
+- Dump pretty-printing of AST
+
+  - Harder to read than round-tripped code, but more accurate to figure out what
+    is going on.
+
+  - Easier to read than dump from built-in AST module
+
+- Non-recursive treewalk
+
+  - Sometimes you want a recursive treewalk (and astor supports that, starting
+    at any node on the tree), but sometimes you don't need to do that. astor
+    doesn't require you to explicitly visit sub-nodes unless you want to:
+
+    - You can add code that executes before a node's children are visited, and/or
+    - You can add code that executes after a node's children are visited, and/or
+    - You can add code that executes and keeps the node's children from being
+      visited (and optionally visit them yourself via a recursive call)
+
+  - Write functions to access the tree based on object names and/or attribute names
+  - Enjoy easy access to parent node(s) for tree rewriting
+
+.. [1]
+    The decompilation back to Python is based on code originally written
+    by Armin Ronacher. Armin's code was well-structured, but failed on
+    some obscure corner cases of the Python language (and even more corner
+    cases when the AST changed on different versions of Python), and its
+    output arguably had cosmetic issues -- for example, it produced
+    parentheses even in some cases where they were not needed, to
+    avoid having to reason about precedence.
+
+    Other derivatives of Armin's code are floating around, and typically
+    have fixes for a few corner cases that happened to be noticed by the
+    maintainers, but most of them have not been tested as thoroughly as
+    astor. One exception may be the version of codegen
+    `maintained at github by CensoredUsername`__. This has been tested
+    to work properly on Python 2.7 using astor's test suite, and, as it
+    is a single source file, it may be easier to drop into some applications
+    that do not require astor's other features or Python 3.x compatibility.
+
+    __ https://github.com/CensoredUsername/codegen
+
+
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/RECORD
ADDED
@@ -0,0 +1,29 @@
+astor-0.8.1.dist-info/AUTHORS,sha256=dy5MQIMINxY79YbaRR19C_CNAgHe3tcuvESs7ypxKQc,679
+astor-0.8.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+astor-0.8.1.dist-info/LICENSE,sha256=zkHq_C78AY2cfJahx3lmgkbHfbEaE544ifNH9GSmG50,1554
+astor-0.8.1.dist-info/METADATA,sha256=0nH_-dzD0tPZUB4Hs5o-OOEuId9lteVELQPI5hG0oKo,4235
+astor-0.8.1.dist-info/RECORD,,
+astor-0.8.1.dist-info/WHEEL,sha256=8zNYZbwQSXoB9IfXOjPfeNwvAsALAjffgk27FqvCWbo,110
+astor-0.8.1.dist-info/top_level.txt,sha256=M5xfrbiL9-EIlOb1h2T8s6gFbV3b9AbwgI0ARzaRyaY,6
+astor-0.8.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+astor/VERSION,sha256=qvZyHcN8QLQjOsz8CB8ld2_zvR0qS51c6nYNHCz4ZmU,6
+astor/__init__.py,sha256=C9rmH4v9K7pkIk3eDuVRhqO5wULt3B42copNJsEw8rw,2291
+astor/__pycache__/__init__.cpython-311.pyc,,
+astor/__pycache__/code_gen.cpython-311.pyc,,
+astor/__pycache__/codegen.cpython-311.pyc,,
+astor/__pycache__/file_util.cpython-311.pyc,,
+astor/__pycache__/node_util.cpython-311.pyc,,
+astor/__pycache__/op_util.cpython-311.pyc,,
+astor/__pycache__/rtrip.cpython-311.pyc,,
+astor/__pycache__/source_repr.cpython-311.pyc,,
+astor/__pycache__/string_repr.cpython-311.pyc,,
+astor/__pycache__/tree_walk.cpython-311.pyc,,
+astor/code_gen.py,sha256=0KAimfyV8pIPXQx6s_NyPSXRhAxMLWXbCPEQuCTpxac,32032
+astor/codegen.py,sha256=lTqdJWMK4EAJ1wxDw2XR-MLyHJmvbV1_Q5QLj9naE_g,204
+astor/file_util.py,sha256=BETsKYg8UiKoZNswRkirzPSZWgku41dRzZC7T5X3_F4,3268
+astor/node_util.py,sha256=WEWMUMSfHtLwgx54nMkc2APLV573iOPhqPag4gIbhVQ,6542
+astor/op_util.py,sha256=GGcgYqa3DFOAaoSt7TTu46VUhe1J13dO14-SQTRXRYI,3191
+astor/rtrip.py,sha256=AlvQvsUuUZ8zxvRFpWF_Fsv4-NksPB23rvVkTrkvef8,6741
+astor/source_repr.py,sha256=1lj4jakkrcGDRoo-BIRZDszQ8gukdeLR_fmvGqBrP-U,7373
+astor/string_repr.py,sha256=YeC_DVeIJdPElqjgzzhPFheQsz_QjMEW_SLODFvEsIA,2917
+astor/tree_walk.py,sha256=fJaw54GgTg4NTRJLVRl2XSnfFOG9GdjOUlI6ZChLOb8,6020
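Each RECORD entry above has the form `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 SHA-256 of the file contents with `=` padding stripped (the wheel RECORD format from PEP 376/PEP 627). A minimal sketch verifying the INSTALLER entry, whose 4-byte content is the `pip` line shown earlier plus a trailing newline:

```python
import base64
import hashlib

# Contents of astor-0.8.1.dist-info/INSTALLER: "pip" plus a newline (4 bytes).
data = b"pip\n"

# RECORD digest: urlsafe base64 of the raw SHA-256 digest, '=' padding stripped.
digest = hashlib.sha256(data).digest()
record_hash = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

# Matches the RECORD line: INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
print(record_hash, len(data))
```

The same check applies to every hashed entry; `.pyc` files and RECORD itself carry empty hash and size fields because they are generated at install time.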
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/WHEEL
ADDED
@@ -0,0 +1,6 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.33.6)
+Root-Is-Purelib: true
+Tag: py2-none-any
+Tag: py3-none-any
+
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
+astor
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/zip-safe
ADDED
@@ -0,0 +1 @@
+
.venv/lib/python3.11/site-packages/huggingface_hub-0.28.1.dist-info/METADATA
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: huggingface-hub
|
| 3 |
+
Version: 0.28.1
|
| 4 |
+
Summary: Client library to download and publish models, datasets and other repos on the huggingface.co hub
|
| 5 |
+
Home-page: https://github.com/huggingface/huggingface_hub
|
| 6 |
+
Author: Hugging Face, Inc.
|
| 7 |
+
Author-email: julien@huggingface.co
|
| 8 |
+
License: Apache
|
| 9 |
+
Keywords: model-hub machine-learning models natural-language-processing deep-learning pytorch pretrained-models
|
| 10 |
+
Platform: UNKNOWN
|
| 11 |
+
Classifier: Intended Audience :: Developers
|
| 12 |
+
Classifier: Intended Audience :: Education
|
| 13 |
+
Classifier: Intended Audience :: Science/Research
|
| 14 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 15 |
+
Classifier: Operating System :: OS Independent
|
| 16 |
+
Classifier: Programming Language :: Python :: 3
|
| 17 |
+
Classifier: Programming Language :: Python :: 3 :: Only
|
| 18 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 20 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 21 |
+
Classifier: Programming Language :: Python :: 3.11
|
| 22 |
+
Classifier: Programming Language :: Python :: 3.12
|
| 23 |
+
Classifier: Programming Language :: Python :: 3.13
|
| 24 |
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
| 25 |
+
Requires-Python: >=3.8.0
|
| 26 |
+
Description-Content-Type: text/markdown
|
| 27 |
+
License-File: LICENSE
|
| 28 |
+
Requires-Dist: filelock
|
| 29 |
+
Requires-Dist: fsspec>=2023.5.0
|
| 30 |
+
Requires-Dist: packaging>=20.9
|
| 31 |
+
Requires-Dist: pyyaml>=5.1
|
| 32 |
+
Requires-Dist: requests
|
| 33 |
+
Requires-Dist: tqdm>=4.42.1
|
| 34 |
+
Requires-Dist: typing-extensions>=3.7.4.3
|
| 35 |
+
Provides-Extra: all
|
| 36 |
+
Requires-Dist: InquirerPy==0.3.4; extra == "all"
|
| 37 |
+
Requires-Dist: aiohttp; extra == "all"
|
| 38 |
+
Requires-Dist: jedi; extra == "all"
|
| 39 |
+
Requires-Dist: Jinja2; extra == "all"
|
| 40 |
+
Requires-Dist: pytest<8.2.2,>=8.1.1; extra == "all"
|
| 41 |
+
Requires-Dist: pytest-cov; extra == "all"
|
| 42 |
+
Requires-Dist: pytest-env; extra == "all"
|
| 43 |
+
Requires-Dist: pytest-xdist; extra == "all"
|
| 44 |
+
Requires-Dist: pytest-vcr; extra == "all"
|
| 45 |
+
Requires-Dist: pytest-asyncio; extra == "all"
|
| 46 |
+
Requires-Dist: pytest-rerunfailures; extra == "all"
|
| 47 |
+
Requires-Dist: pytest-mock; extra == "all"
|
| 48 |
+
Requires-Dist: urllib3<2.0; extra == "all"
|
| 49 |
+
Requires-Dist: soundfile; extra == "all"
|
| 50 |
+
Requires-Dist: Pillow; extra == "all"
|
| 51 |
+
Requires-Dist: gradio>=4.0.0; extra == "all"
|
| 52 |
+
Requires-Dist: numpy; extra == "all"
|
| 53 |
+
Requires-Dist: fastapi; extra == "all"
|
| 54 |
+
Requires-Dist: ruff>=0.9.0; extra == "all"
|
| 55 |
+
Requires-Dist: mypy==1.5.1; extra == "all"
|
| 56 |
+
Requires-Dist: libcst==1.4.0; extra == "all"
|
| 57 |
+
Requires-Dist: typing-extensions>=4.8.0; extra == "all"
|
| 58 |
+
Requires-Dist: types-PyYAML; extra == "all"
|
| 59 |
+
Requires-Dist: types-requests; extra == "all"
|
| 60 |
+
Requires-Dist: types-simplejson; extra == "all"
|
| 61 |
+
Requires-Dist: types-toml; extra == "all"
|
| 62 |
+
Requires-Dist: types-tqdm; extra == "all"
|
| 63 |
+
Requires-Dist: types-urllib3; extra == "all"
|
| 64 |
+
Provides-Extra: cli
|
| 65 |
+
Requires-Dist: InquirerPy==0.3.4; extra == "cli"
|
| 66 |
+
Provides-Extra: dev
|
| 67 |
+
Requires-Dist: InquirerPy==0.3.4; extra == "dev"
|
| 68 |
+
Requires-Dist: aiohttp; extra == "dev"
|
| 69 |
+
Requires-Dist: jedi; extra == "dev"
|
| 70 |
+
Requires-Dist: Jinja2; extra == "dev"
|
| 71 |
+
Requires-Dist: pytest<8.2.2,>=8.1.1; extra == "dev"
|
| 72 |
+
Requires-Dist: pytest-cov; extra == "dev"
|
| 73 |
+
Requires-Dist: pytest-env; extra == "dev"
|
| 74 |
+
Requires-Dist: pytest-xdist; extra == "dev"
|
| 75 |
+
Requires-Dist: pytest-vcr; extra == "dev"
|
| 76 |
+
Requires-Dist: pytest-asyncio; extra == "dev"
|
| 77 |
+
Requires-Dist: pytest-rerunfailures; extra == "dev"
|
| 78 |
+
Requires-Dist: pytest-mock; extra == "dev"
|
| 79 |
+
Requires-Dist: urllib3<2.0; extra == "dev"
|
| 80 |
+
Requires-Dist: soundfile; extra == "dev"
|
| 81 |
+
Requires-Dist: Pillow; extra == "dev"
|
| 82 |
+
Requires-Dist: gradio>=4.0.0; extra == "dev"
|
| 83 |
+
Requires-Dist: numpy; extra == "dev"
|
| 84 |
+
Requires-Dist: fastapi; extra == "dev"
|
| 85 |
+
Requires-Dist: ruff>=0.9.0; extra == "dev"
|
| 86 |
+
Requires-Dist: mypy==1.5.1; extra == "dev"
|
| 87 |
+
Requires-Dist: libcst==1.4.0; extra == "dev"
|
| 88 |
+
Requires-Dist: typing-extensions>=4.8.0; extra == "dev"
|
| 89 |
+
Requires-Dist: types-PyYAML; extra == "dev"
|
| 90 |
+
Requires-Dist: types-requests; extra == "dev"
|
| 91 |
+
Requires-Dist: types-simplejson; extra == "dev"
|
| 92 |
+
Requires-Dist: types-toml; extra == "dev"
|
| 93 |
+
Requires-Dist: types-tqdm; extra == "dev"
|
| 94 |
+
Requires-Dist: types-urllib3; extra == "dev"
|
| 95 |
+
Provides-Extra: fastai
|
| 96 |
+
Requires-Dist: toml; extra == "fastai"
|
| 97 |
+
Requires-Dist: fastai>=2.4; extra == "fastai"
|
| 98 |
+
Requires-Dist: fastcore>=1.3.27; extra == "fastai"
|
| 99 |
+
Provides-Extra: hf_transfer
|
| 100 |
+
Requires-Dist: hf-transfer>=0.1.4; extra == "hf-transfer"
|
| 101 |
+
Provides-Extra: inference
|
| 102 |
+
Requires-Dist: aiohttp; extra == "inference"
|
| 103 |
+
Provides-Extra: quality
|
| 104 |
+
Requires-Dist: ruff>=0.9.0; extra == "quality"
|
| 105 |
+
Requires-Dist: mypy==1.5.1; extra == "quality"
|
| 106 |
+
Requires-Dist: libcst==1.4.0; extra == "quality"
|
| 107 |
+
Provides-Extra: tensorflow
|
| 108 |
+
Requires-Dist: tensorflow; extra == "tensorflow"
|
| 109 |
+
Requires-Dist: pydot; extra == "tensorflow"
|
| 110 |
+
Requires-Dist: graphviz; extra == "tensorflow"
|
| 111 |
+
Provides-Extra: tensorflow-testing
|
| 112 |
+
Requires-Dist: tensorflow; extra == "tensorflow-testing"
|
| 113 |
+
Requires-Dist: keras<3.0; extra == "tensorflow-testing"
|
| 114 |
+
Provides-Extra: testing
|
| 115 |
+
Requires-Dist: InquirerPy==0.3.4; extra == "testing"
|
| 116 |
+
Requires-Dist: aiohttp; extra == "testing"
|
| 117 |
+
Requires-Dist: jedi; extra == "testing"
|
| 118 |
+
Requires-Dist: Jinja2; extra == "testing"
|
| 119 |
+
Requires-Dist: pytest<8.2.2,>=8.1.1; extra == "testing"
|
| 120 |
+
Requires-Dist: pytest-cov; extra == "testing"
|
| 121 |
+
Requires-Dist: pytest-env; extra == "testing"
|
| 122 |
+
Requires-Dist: pytest-xdist; extra == "testing"
|
| 123 |
+
Requires-Dist: pytest-vcr; extra == "testing"
|
| 124 |
+
Requires-Dist: pytest-asyncio; extra == "testing"
|
| 125 |
+
Requires-Dist: pytest-rerunfailures; extra == "testing"
|
| 126 |
+
Requires-Dist: pytest-mock; extra == "testing"
|
| 127 |
+
Requires-Dist: urllib3<2.0; extra == "testing"
|
| 128 |
+
Requires-Dist: soundfile; extra == "testing"
|
| 129 |
+
Requires-Dist: Pillow; extra == "testing"
|
| 130 |
+
Requires-Dist: gradio>=4.0.0; extra == "testing"
|
| 131 |
+
Requires-Dist: numpy; extra == "testing"
|
| 132 |
+
Requires-Dist: fastapi; extra == "testing"
|
| 133 |
+
Provides-Extra: torch
|
| 134 |
+
Requires-Dist: torch; extra == "torch"
|
| 135 |
+
Requires-Dist: safetensors[torch]; extra == "torch"
|
| 136 |
+
Provides-Extra: typing
|
| 137 |
+
Requires-Dist: typing-extensions>=4.8.0; extra == "typing"
|
| 138 |
+
Requires-Dist: types-PyYAML; extra == "typing"
|
| 139 |
+
Requires-Dist: types-requests; extra == "typing"
|
| 140 |
+
Requires-Dist: types-simplejson; extra == "typing"
|
| 141 |
+
Requires-Dist: types-toml; extra == "typing"
|
| 142 |
+
Requires-Dist: types-tqdm; extra == "typing"
|
| 143 |
+
Requires-Dist: types-urllib3; extra == "typing"
|
| 144 |
+
|
| 145 |
+
<p align="center">
|
| 146 |
+
<picture>
|
| 147 |
+
<source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub-dark.svg">
|
| 148 |
+
<source media="(prefers-color-scheme: light)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub.svg">
|
| 149 |
+
<img alt="huggingface_hub library logo" src="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub.svg" width="352" height="59" style="max-width: 100%;">
|
| 150 |
+
</picture>
|
| 151 |
+
<br/>
|
| 152 |
+
<br/>
|
| 153 |
+
</p>
|
| 154 |
+
|
| 155 |
+
<p align="center">
|
| 156 |
+
<i>The official Python client for the Huggingface Hub.</i>
|
| 157 |
+
</p>
|
| 158 |
+
|
| 159 |
+
<p align="center">
|
| 160 |
+
<a href="https://huggingface.co/docs/huggingface_hub/en/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/huggingface_hub/index.svg?down_color=red&down_message=offline&up_message=online&label=doc"></a>
|
| 161 |
+
<a href="https://github.com/huggingface/huggingface_hub/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/huggingface_hub.svg"></a>
|
| 162 |
+
<a href="https://github.com/huggingface/huggingface_hub"><img alt="PyPi version" src="https://img.shields.io/pypi/pyversions/huggingface_hub.svg"></a>
|
| 163 |
+
<a href="https://pypi.org/project/huggingface-hub"><img alt="PyPI - Downloads" src="https://img.shields.io/pypi/dm/huggingface_hub"></a>
|
| 164 |
+
<a href="https://codecov.io/gh/huggingface/huggingface_hub"><img alt="Code coverage" src="https://codecov.io/gh/huggingface/huggingface_hub/branch/main/graph/badge.svg?token=RXP95LE2XL"></a>
|
| 165 |
+
</p>
|
| 166 |
+
|
| 167 |
+
<h4 align="center">
|
| 168 |
+
<p>
|
| 169 |
+
<b>English</b> |
|
| 170 |
+
<a href="https://github.com/huggingface/huggingface_hub/blob/main/i18n/README_de.md">Deutsch</a> |
|
| 171 |
+
<a href="https://github.com/huggingface/huggingface_hub/blob/main/i18n/README_hi.md">हिंदी</a> |
|
| 172 |
+
<a href="https://github.com/huggingface/huggingface_hub/blob/main/i18n/README_ko.md">한국어</a> |
|
| 173 |
+
<a href="https://github.com/huggingface/huggingface_hub/blob/main/i18n/README_cn.md">中文(简体)</a>
|
| 174 |
+
<p>
|
| 175 |
+
</h4>
|
| 176 |
+
|
| 177 |
+
---
|
| 178 |
+
|
| 179 |
+
**Documentation**: <a href="https://hf.co/docs/huggingface_hub" target="_blank">https://hf.co/docs/huggingface_hub</a>
|
| 180 |
+
|
| 181 |
+
**Source Code**: <a href="https://github.com/huggingface/huggingface_hub" target="_blank">https://github.com/huggingface/huggingface_hub</a>
|
| 182 |
+
|
| 183 |
+
---
|
| 184 |
+
|
| 185 |
+
## Welcome to the huggingface_hub library
|
| 186 |
+
|
| 187 |
+
The `huggingface_hub` library allows you to interact with the [Hugging Face Hub](https://huggingface.co/), a platform democratizing open-source Machine Learning for creators and collaborators. Discover pre-trained models and datasets for your projects or play with the thousands of machine learning apps hosted on the Hub. You can also create and share your own models, datasets and demos with the community. The `huggingface_hub` library provides a simple way to do all these things with Python.
|
| 188 |
+
|
| 189 |
+
## Key features
|
| 190 |
+
|
| 191 |
+
- [Download files](https://huggingface.co/docs/huggingface_hub/en/guides/download) from the Hub.
|
| 192 |
+
- [Upload files](https://huggingface.co/docs/huggingface_hub/en/guides/upload) to the Hub.
|
| 193 |
+
- [Manage your repositories](https://huggingface.co/docs/huggingface_hub/en/guides/repository).
|
| 194 |
+
- [Run Inference](https://huggingface.co/docs/huggingface_hub/en/guides/inference) on deployed models.
|
| 195 |
- [Search](https://huggingface.co/docs/huggingface_hub/en/guides/search) for models, datasets and Spaces.
- [Share Model Cards](https://huggingface.co/docs/huggingface_hub/en/guides/model-cards) to document your models.
- [Engage with the community](https://huggingface.co/docs/huggingface_hub/en/guides/community) through PRs and comments.

## Installation

Install the `huggingface_hub` package with [pip](https://pypi.org/project/huggingface-hub/):

```bash
pip install huggingface_hub
```

If you prefer, you can also install it with [conda](https://huggingface.co/docs/huggingface_hub/en/installation#install-with-conda).

In order to keep the package minimal by default, `huggingface_hub` comes with optional dependencies useful for some use cases. For example, if you want to have a complete experience for Inference, run:

```bash
pip install huggingface_hub[inference]
```

To learn more about installation and optional dependencies, check out the [installation guide](https://huggingface.co/docs/huggingface_hub/en/installation).

## Quick start

### Download files

Download a single file:

```py
from huggingface_hub import hf_hub_download

hf_hub_download(repo_id="tiiuae/falcon-7b-instruct", filename="config.json")
```

Or an entire repository:

```py
from huggingface_hub import snapshot_download

snapshot_download("stabilityai/stable-diffusion-2-1")
```

Files are downloaded to a local cache folder. More details in [this guide](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache).
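As a rough sketch of how that cache is organized (the default location and the `models--org--name` folder naming below are assumptions based on current `huggingface_hub` behavior and may change between versions):

```python
import os

# Assumed default cache directory used by huggingface_hub; it can be
# redirected with the HF_HOME or HF_HUB_CACHE environment variables.
DEFAULT_HUB_CACHE = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")

def repo_cache_folder(repo_id: str, repo_type: str = "model") -> str:
    """Folder name the cache uses for one repo,
    e.g. 'models--tiiuae--falcon-7b-instruct' (assumed naming scheme)."""
    return f"{repo_type}s--" + repo_id.replace("/", "--")

print(os.path.join(DEFAULT_HUB_CACHE, repo_cache_folder("tiiuae/falcon-7b-instruct")))
```

Each repo folder then holds the downloaded snapshots and blobs; see the manage-cache guide linked above for the authoritative layout.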
### Login

The Hugging Face Hub uses tokens to authenticate applications (see [docs](https://huggingface.co/docs/hub/security-tokens)). To log in from your machine, run the following CLI command:

```bash
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```
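Once logged in, most `huggingface_hub` calls pick up the token automatically. A minimal sketch of the lookup order — explicit `token=` argument, then the `HF_TOKEN` environment variable, then the token file saved by `huggingface-cli login`; the exact precedence and the file location are assumptions about current library versions:

```python
import os

# Assumed location of the token file written by `huggingface-cli login`.
TOKEN_FILE = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "token")

def resolve_token(explicit=None):
    """Sketch of token resolution: an explicit argument wins, then the
    HF_TOKEN environment variable, then the saved token file."""
    if explicit:
        return explicit
    env_token = os.environ.get("HF_TOKEN")
    if env_token:
        return env_token
    if os.path.isfile(TOKEN_FILE):
        with open(TOKEN_FILE) as f:
            return f.read().strip()
    return None
```

Passing `token=...` explicitly to individual calls is always available when you need to override whatever is stored on the machine.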
### Create a repository

```py
from huggingface_hub import create_repo

create_repo(repo_id="super-cool-model")
```

### Upload files

Upload a single file:

```py
from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="/home/lysandre/dummy-test/README.md",
    path_in_repo="README.md",
    repo_id="lysandre/test-model",
)
```

Or an entire folder:

```py
from huggingface_hub import upload_folder

upload_folder(
    folder_path="/path/to/local/space",
    repo_id="username/my-cool-space",
    repo_type="space",
)
```

For more details, check out the [upload guide](https://huggingface.co/docs/huggingface_hub/en/guides/upload).

## Integrating with the Hub

We're partnering with cool open source ML libraries to provide free model hosting and versioning. You can find the existing integrations [here](https://huggingface.co/docs/hub/libraries).

The advantages are:

- Free model or dataset hosting for libraries and their users.
- Built-in file versioning, even with very large files, thanks to a git-based approach.
- Serverless inference API for all models publicly available.
- In-browser widgets to play with the uploaded models.
- Anyone can upload a new model for your library; they just need to add the corresponding tag for the model to be discoverable.
- Fast downloads! We use Cloudfront (a CDN) to geo-replicate downloads so they're blazing fast from anywhere on the globe.
- Usage stats and more features to come.

If you would like to integrate your library, feel free to open an issue to begin the discussion. We wrote a [step-by-step guide](https://huggingface.co/docs/hub/adding-a-library) with ❤️ showing how to do this integration.

## Contributions (feature requests, bugs, etc.) are super welcome 💙💚💛💜🧡❤️

Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community.
Answering questions, helping others, reaching out and improving the documentation are immensely valuable to the community.
We wrote a [contribution guide](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) to summarize
how to get started contributing to this repository.
.venv/lib/python3.11/site-packages/huggingface_hub-0.28.1.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
huggingface_hub
.venv/lib/python3.11/site-packages/nvidia_cuda_cupti_cu12-12.4.127.dist-info/METADATA
ADDED
@@ -0,0 +1,35 @@
Metadata-Version: 2.1
Name: nvidia-cuda-cupti-cu12
Version: 12.4.127
Summary: CUDA profiling tools runtime libs.
Home-page: https://developer.nvidia.com/cuda-zone
Author: Nvidia CUDA Installer Team
Author-email: cuda_installer@nvidia.com
License: NVIDIA Proprietary Software
Keywords: cuda,nvidia,runtime,machine learning,deep learning
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: Other/Proprietary License
Classifier: Natural Language :: English
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Mathematics
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development
Classifier: Topic :: Software Development :: Libraries
Classifier: Operating System :: Microsoft :: Windows
Classifier: Operating System :: POSIX :: Linux
Requires-Python: >=3
License-File: License.txt

Provides libraries to enable third party tools using GPU profiling APIs.
.venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/LICENSE
ADDED
@@ -0,0 +1,22 @@
The MIT License (MIT)

Copyright (c) 2015 Radim Řehůřek

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
.venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,586 @@
Metadata-Version: 2.1
Name: smart-open
Version: 7.1.0
Summary: Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)
Home-page: https://github.com/piskvorky/smart_open
Author: Radim Rehurek
Author-email: me@radimrehurek.com
Maintainer: Radim Rehurek
Maintainer-email: me@radimrehurek.com
License: MIT
Download-URL: http://pypi.python.org/pypi/smart_open
Keywords: file streaming,s3,hdfs,gcs,azure blob storage
Platform: any
Classifier: Development Status :: 5 - Production/Stable
Classifier: Environment :: Console
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: System :: Distributed Computing
Classifier: Topic :: Database :: Front-Ends
Requires-Python: >=3.7,<4.0
Requires-Dist: wrapt
Provides-Extra: all
Requires-Dist: boto3; extra == "all"
Requires-Dist: google-cloud-storage>=2.6.0; extra == "all"
Requires-Dist: azure-storage-blob; extra == "all"
Requires-Dist: azure-common; extra == "all"
Requires-Dist: azure-core; extra == "all"
Requires-Dist: requests; extra == "all"
Requires-Dist: paramiko; extra == "all"
Requires-Dist: zstandard; extra == "all"
Provides-Extra: azure
Requires-Dist: azure-storage-blob; extra == "azure"
Requires-Dist: azure-common; extra == "azure"
Requires-Dist: azure-core; extra == "azure"
Provides-Extra: gcs
Requires-Dist: google-cloud-storage>=2.6.0; extra == "gcs"
Provides-Extra: http
Requires-Dist: requests; extra == "http"
Provides-Extra: s3
Requires-Dist: boto3; extra == "s3"
Provides-Extra: ssh
Requires-Dist: paramiko; extra == "ssh"
Provides-Extra: test
Requires-Dist: boto3; extra == "test"
Requires-Dist: google-cloud-storage>=2.6.0; extra == "test"
Requires-Dist: azure-storage-blob; extra == "test"
Requires-Dist: azure-common; extra == "test"
Requires-Dist: azure-core; extra == "test"
Requires-Dist: requests; extra == "test"
Requires-Dist: paramiko; extra == "test"
Requires-Dist: zstandard; extra == "test"
Requires-Dist: moto[server]; extra == "test"
Requires-Dist: responses; extra == "test"
Requires-Dist: pytest; extra == "test"
Requires-Dist: pytest-rerunfailures; extra == "test"
Requires-Dist: pytest-benchmark; extra == "test"
Requires-Dist: awscli; extra == "test"
Requires-Dist: pyopenssl; extra == "test"
Requires-Dist: numpy; extra == "test"
Provides-Extra: webhdfs
Requires-Dist: requests; extra == "webhdfs"
Provides-Extra: zst
Requires-Dist: zstandard; extra == "zst"

|
| 74 |
+
smart_open — utils for streaming large files in Python
|
| 75 |
+
======================================================
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|License|_ |GHA|_ |Coveralls|_ |Downloads|_
|
| 79 |
+
|
| 80 |
+
.. |License| image:: https://img.shields.io/pypi/l/smart_open.svg
|
| 81 |
+
.. |GHA| image:: https://github.com/RaRe-Technologies/smart_open/workflows/Test/badge.svg
|
| 82 |
+
.. |Coveralls| image:: https://coveralls.io/repos/github/RaRe-Technologies/smart_open/badge.svg?branch=develop
|
| 83 |
+
.. |Downloads| image:: https://pepy.tech/badge/smart-open/month
|
| 84 |
+
.. _License: https://github.com/RaRe-Technologies/smart_open/blob/master/LICENSE
|
| 85 |
+
.. _GHA: https://github.com/RaRe-Technologies/smart_open/actions?query=workflow%3ATest
|
| 86 |
+
.. _Coveralls: https://coveralls.io/github/RaRe-Technologies/smart_open?branch=HEAD
|
| 87 |
+
.. _Downloads: https://pypi.org/project/smart-open/
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
What?
|
| 91 |
+
=====
|
| 92 |
+
|
| 93 |
+
``smart_open`` is a Python 3 library for **efficient streaming of very large files** from/to storages such as S3, GCS, Azure Blob Storage, HDFS, WebHDFS, HTTP, HTTPS, SFTP, or local filesystem. It supports transparent, on-the-fly (de-)compression for a variety of different formats.
|
| 94 |
+
|
| 95 |
+
``smart_open`` is a drop-in replacement for Python's built-in ``open()``: it can do anything ``open`` can (100% compatible, falls back to native ``open`` wherever possible), plus lots of nifty extra stuff on top.
|
| 96 |
+
|
| 97 |
+
**Python 2.7 is no longer supported. If you need Python 2.7, please use** `smart_open 1.10.1 <https://github.com/RaRe-Technologies/smart_open/releases/tag/1.10.0>`_, **the last version to support Python 2.**
|
| 98 |
+
|
| 99 |
+
Why?
|
| 100 |
+
====
|
| 101 |
+
|
| 102 |
+
Working with large remote files, for example using Amazon's `boto3 <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>`_ Python library, is a pain.
|
| 103 |
+
``boto3``'s ``Object.upload_fileobj()`` and ``Object.download_fileobj()`` methods require gotcha-prone boilerplate to use successfully, such as constructing file-like object wrappers.
|
| 104 |
+
``smart_open`` shields you from that. It builds on boto3 and other remote storage libraries, but offers a **clean unified Pythonic API**. The result is less code for you to write and fewer bugs to make.
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
How?
|
| 108 |
+
=====
|
| 109 |
+
|
| 110 |
+
``smart_open`` is well-tested, well-documented, and has a simple Pythonic API:
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
.. _doctools_before_examples:
|
| 114 |
+
|
| 115 |
+
.. code-block:: python
|
| 116 |
+
|
| 117 |
+
>>> from smart_open import open
|
| 118 |
+
>>>
|
| 119 |
+
>>> # stream lines from an S3 object
|
| 120 |
+
>>> for line in open('s3://commoncrawl/robots.txt'):
|
| 121 |
+
... print(repr(line))
|
| 122 |
+
... break
|
| 123 |
+
'User-Agent: *\n'
|
| 124 |
+
|
| 125 |
+
>>> # stream from/to compressed files, with transparent (de)compression:
|
| 126 |
+
>>> for line in open('smart_open/tests/test_data/1984.txt.gz', encoding='utf-8'):
|
| 127 |
+
... print(repr(line))
|
| 128 |
+
'It was a bright cold day in April, and the clocks were striking thirteen.\n'
|
| 129 |
+
'Winston Smith, his chin nuzzled into his breast in an effort to escape the vile\n'
|
| 130 |
+
'wind, slipped quickly through the glass doors of Victory Mansions, though not\n'
|
| 131 |
+
'quickly enough to prevent a swirl of gritty dust from entering along with him.\n'
|
| 132 |
+
|
| 133 |
+
>>> # can use context managers too:
|
| 134 |
+
>>> with open('smart_open/tests/test_data/1984.txt.gz') as fin:
|
| 135 |
+
... with open('smart_open/tests/test_data/1984.txt.bz2', 'w') as fout:
|
| 136 |
+
... for line in fin:
|
| 137 |
+
... fout.write(line)
|
| 138 |
+
74
|
| 139 |
+
80
|
| 140 |
+
78
|
| 141 |
+
79
|
| 142 |
+
|
| 143 |
+
>>> # can use any IOBase operations, like seek
|
| 144 |
+
>>> with open('s3://commoncrawl/robots.txt', 'rb') as fin:
|
| 145 |
+
... for line in fin:
|
| 146 |
+
... print(repr(line.decode('utf-8')))
|
| 147 |
+
... break
|
| 148 |
+
... offset = fin.seek(0) # seek to the beginning
|
| 149 |
+
... print(fin.read(4))
|
| 150 |
+
'User-Agent: *\n'
|
| 151 |
+
b'User'
|
| 152 |
+
|
| 153 |
+
>>> # stream from HTTP
|
| 154 |
+
>>> for line in open('http://example.com/index.html'):
|
| 155 |
+
... print(repr(line))
|
| 156 |
+
... break
|
| 157 |
+
'<!doctype html>\n'
|
| 158 |
+
|
| 159 |
+
.. _doctools_after_examples:
|
| 160 |
+
|
| 161 |
+
Other examples of URLs that ``smart_open`` accepts::
|
| 162 |
+
|
| 163 |
+
s3://my_bucket/my_key
|
| 164 |
+
s3://my_key:my_secret@my_bucket/my_key
|
| 165 |
+
s3://my_key:my_secret@my_server:my_port@my_bucket/my_key
|
| 166 |
+
gs://my_bucket/my_blob
|
| 167 |
+
azure://my_bucket/my_blob
|
| 168 |
+
hdfs:///path/file
|
| 169 |
+
hdfs://path/file
|
| 170 |
+
webhdfs://host:port/path/file
|
| 171 |
+
./local/path/file
|
| 172 |
+
~/local/path/file
|
| 173 |
+
local/path/file
|
| 174 |
+
./local/path/file.gz
|
| 175 |
+
file:///home/user/file
|
| 176 |
+
file:///home/user/file.bz2
|
| 177 |
+
[ssh|scp|sftp]://username@host//path/file
|
| 178 |
+
[ssh|scp|sftp]://username@host/path/file
|
| 179 |
+
[ssh|scp|sftp]://username:password@host/path/file
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
Documentation
=============

Installation
------------

``smart_open`` supports a wide range of storage solutions, including AWS S3, Google Cloud and Azure.
Each individual solution has its own dependencies.
By default, ``smart_open`` does not install any dependencies, in order to keep the installation size small.
You can install these dependencies explicitly using::

    pip install smart_open[azure]  # Install Azure deps
    pip install smart_open[gcs]    # Install GCS deps
    pip install smart_open[s3]     # Install S3 deps

Or, if you don't mind installing a large number of third party libraries, you can install all dependencies using::

    pip install smart_open[all]

Be warned that this option increases the installation size significantly, e.g. over 100MB.

If you're upgrading from ``smart_open`` versions 2.x and below, please check out the `Migration Guide <MIGRATING_FROM_OLDER_VERSIONS.rst>`_.

Built-in help
-------------

For detailed API info, see the online help:

.. code-block:: python

    help('smart_open')

or click `here <https://github.com/RaRe-Technologies/smart_open/blob/master/help.txt>`__ to view the help in your browser.

More examples
-------------

For the sake of simplicity, the examples below assume you have all the dependencies installed, i.e. you have done::

    pip install smart_open[all]

.. code-block:: python

    >>> import os, boto3
    >>> from smart_open import open
    >>>
    >>> # stream content *into* S3 (write mode) using a custom session
    >>> session = boto3.Session(
    ...     aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
    ...     aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
    ... )
    >>> url = 's3://smart-open-py37-benchmark-results/test.txt'
    >>> with open(url, 'wb', transport_params={'client': session.client('s3')}) as fout:
    ...     bytes_written = fout.write(b'hello world!')
    ...     print(bytes_written)
    12

.. code-block:: python

    # stream from HDFS
    for line in open('hdfs://user/hadoop/my_file.txt', encoding='utf8'):
        print(line)

    # stream from WebHDFS
    for line in open('webhdfs://host:port/user/hadoop/my_file.txt'):
        print(line)

    # stream content *into* HDFS (write mode):
    with open('hdfs://host:port/user/hadoop/my_file.txt', 'wb') as fout:
        fout.write(b'hello world')

    # stream content *into* WebHDFS (write mode):
    with open('webhdfs://host:port/user/hadoop/my_file.txt', 'wb') as fout:
        fout.write(b'hello world')

    # stream from a completely custom s3 server, like s3proxy:
    for line in open('s3u://user:secret@host:port@mybucket/mykey.txt'):
        print(line)

    # Stream to Digital Ocean Spaces bucket providing credentials from boto3 profile
    session = boto3.Session(profile_name='digitalocean')
    client = session.client('s3', endpoint_url='https://ams3.digitaloceanspaces.com')
    transport_params = {'client': client}
    with open('s3://bucket/key.txt', 'wb', transport_params=transport_params) as fout:
        fout.write(b'here we stand')

    # stream from GCS
    for line in open('gs://my_bucket/my_file.txt'):
        print(line)

    # stream content *into* GCS (write mode):
    with open('gs://my_bucket/my_file.txt', 'wb') as fout:
        fout.write(b'hello world')

    # stream from Azure Blob Storage
    connect_str = os.environ['AZURE_STORAGE_CONNECTION_STRING']
    transport_params = {
        'client': azure.storage.blob.BlobServiceClient.from_connection_string(connect_str),
    }
    for line in open('azure://mycontainer/myfile.txt', transport_params=transport_params):
        print(line)

    # stream content *into* Azure Blob Storage (write mode):
    connect_str = os.environ['AZURE_STORAGE_CONNECTION_STRING']
    transport_params = {
        'client': azure.storage.blob.BlobServiceClient.from_connection_string(connect_str),
    }
    with open('azure://mycontainer/my_file.txt', 'wb', transport_params=transport_params) as fout:
        fout.write(b'hello world')

Compression Handling
|
| 293 |
+
--------------------
|
| 294 |
+
|
| 295 |
+
The top-level `compression` parameter controls compression/decompression behavior when reading and writing.
|
| 296 |
+
The supported values for this parameter are:
|
| 297 |
+
|
| 298 |
+
- ``infer_from_extension`` (default behavior)
|
| 299 |
+
- ``disable``
|
| 300 |
+
- ``.gz``
|
| 301 |
+
- ``.bz2``
|
| 302 |
+
- ``.zst``
|
| 303 |
+
|
| 304 |
+
By default, ``smart_open`` determines the compression algorithm to use based on the file extension.
|
| 305 |
+
|
| 306 |
+
.. code-block:: python
|
| 307 |
+
|
| 308 |
+
>>> from smart_open import open, register_compressor
|
| 309 |
+
>>> with open('smart_open/tests/test_data/1984.txt.gz') as fin:
|
| 310 |
+
... print(fin.read(32))
|
| 311 |
+
It was a bright cold day in Apri
|
| 312 |
+
|
| 313 |
+
You can override this behavior to either disable compression, or explicitly specify the algorithm to use.
|
| 314 |
+
To disable compression:
|
| 315 |
+
|
| 316 |
+
.. code-block:: python
|
| 317 |
+
|
| 318 |
+
>>> from smart_open import open, register_compressor
|
| 319 |
+
>>> with open('smart_open/tests/test_data/1984.txt.gz', 'rb', compression='disable') as fin:
|
| 320 |
+
... print(fin.read(32))
|
| 321 |
+
b'\x1f\x8b\x08\x08\x85F\x94\\\x00\x031984.txt\x005\x8f=r\xc3@\x08\x85{\x9d\xe2\x1d@'
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
To specify the algorithm explicitly (e.g. for non-standard file extensions):
|
| 325 |
+
|
| 326 |
+
.. code-block:: python
|
| 327 |
+
|
| 328 |
+
>>> from smart_open import open, register_compressor
|
| 329 |
+
>>> with open('smart_open/tests/test_data/1984.txt.gzip', compression='.gz') as fin:
|
| 330 |
+
... print(fin.read(32))
|
| 331 |
+
It was a bright cold day in Apri
|
| 332 |
+
|
| 333 |
+
You can also easily add support for other file extensions and compression formats.
|
| 334 |
+
For example, to open xz-compressed files:
|
| 335 |
+
|
| 336 |
+
.. code-block:: python
|
| 337 |
+
|
| 338 |
+
>>> import lzma, os
|
| 339 |
+
>>> from smart_open import open, register_compressor
|
| 340 |
+
|
| 341 |
+
>>> def _handle_xz(file_obj, mode):
|
| 342 |
+
... return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
|
| 343 |
+
|
| 344 |
+
>>> register_compressor('.xz', _handle_xz)
|
| 345 |
+
|
| 346 |
+
>>> with open('smart_open/tests/test_data/1984.txt.xz') as fin:
|
| 347 |
+
... print(fin.read(32))
|
| 348 |
+
It was a bright cold day in Apri
|
| 349 |
+
|
| 350 |
+
``lzma`` is in the standard library in Python 3.3 and greater.
|
| 351 |
+
For 2.7, use `backports.lzma`_.
|
| 352 |
+
|
| 353 |
+
.. _backports.lzma: https://pypi.org/project/backports.lzma/
|
| 354 |
+
|
| 355 |
+
Transport-specific Options
|
| 356 |
+
--------------------------
|
| 357 |
+
|
| 358 |
+
``smart_open`` supports a wide range of transport options out of the box, including:
|
| 359 |
+
|
| 360 |
+
- S3
|
| 361 |
+
- HTTP, HTTPS (read-only)
|
| 362 |
+
- SSH, SCP and SFTP
|
| 363 |
+
- WebHDFS
|
| 364 |
+
- GCS
|
| 365 |
+
- Azure Blob Storage
|
| 366 |
+
|
| 367 |
+
Each option involves setting up its own set of parameters.
|
| 368 |
+
For example, for accessing S3, you often need to set up authentication, like API keys or a profile name.
|
| 369 |
+
``smart_open``'s ``open`` function accepts a keyword argument ``transport_params`` which accepts additional parameters for the transport layer.
|
| 370 |
+
Here are some examples of using this parameter:
|
| 371 |
+
|
| 372 |
+
.. code-block:: python
|
| 373 |
+
|
| 374 |
+
>>> import boto3
|
| 375 |
+
>>> fin = open('s3://commoncrawl/robots.txt', transport_params=dict(client=boto3.client('s3')))
|
| 376 |
+
>>> fin = open('s3://commoncrawl/robots.txt', transport_params=dict(buffer_size=1024))
|
| 377 |
+
|
| 378 |
+
For the full list of keyword arguments supported by each transport option, see the documentation:
|
| 379 |
+
|
| 380 |
+
.. code-block:: python
|
| 381 |
+
|
| 382 |
+
help('smart_open.open')
|
| 383 |
+
|
| 384 |
+
S3 Credentials
|
| 385 |
+
--------------
|
| 386 |
+
|
| 387 |
+
``smart_open`` uses the ``boto3`` library to talk to S3.
|
| 388 |
+
``boto3`` has several `mechanisms <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html>`__ for determining the credentials to use.
|
| 389 |
+
By default, ``smart_open`` will defer to ``boto3`` and let the latter take care of the credentials.
|
| 390 |
+
There are several ways to override this behavior.
|
| 391 |
+
|
| 392 |
+
The first is to pass a ``boto3.Client`` object as a transport parameter to the ``open`` function.
|
| 393 |
+
You can customize the credentials when constructing the session for the client.
|
| 394 |
+
``smart_open`` will then use the session when talking to S3.
|
| 395 |
+
|
| 396 |
+
.. code-block:: python
|
| 397 |
+
|
| 398 |
+
session = boto3.Session(
|
| 399 |
+
aws_access_key_id=ACCESS_KEY,
|
| 400 |
+
aws_secret_access_key=SECRET_KEY,
|
| 401 |
+
aws_session_token=SESSION_TOKEN,
|
| 402 |
+
)
|
| 403 |
+
client = session.client('s3', endpoint_url=..., config=...)
|
| 404 |
+
fin = open('s3://bucket/key', transport_params={'client': client})
|
| 405 |
+
|
| 406 |
+
Your second option is to specify the credentials within the S3 URL itself:
|
| 407 |
+
|
| 408 |
+
.. code-block:: python
|
| 409 |
+
|
| 410 |
+
fin = open('s3://aws_access_key_id:aws_secret_access_key@bucket/key', ...)
|
| 411 |
+
|
| 412 |
+
*Important*: The two methods above are **mutually exclusive**. If you pass an AWS client *and* the URL contains credentials, ``smart_open`` will ignore the latter.
|
| 413 |
+
|
| 414 |
+
*Important*: ``smart_open`` ignores configuration files from the older ``boto`` library.
|
| 415 |
+
Port your old ``boto`` settings to ``boto3`` in order to use them with ``smart_open``.
|
| 416 |
+
|
| 417 |
+
S3 Advanced Usage
|
| 418 |
+
-----------------
|
| 419 |
+
|
| 420 |
+
Additional keyword arguments can be propagated to the boto3 methods that are used by ``smart_open`` under the hood using the ``client_kwargs`` transport parameter.
|
| 421 |
+
|
| 422 |
+
For instance, to upload a blob with Metadata, ACL, StorageClass, these keyword arguments can be passed to ``create_multipart_upload`` (`docs <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.create_multipart_upload>`__).
|
| 423 |
+
|
| 424 |
+
.. code-block:: python
|
| 425 |
+
|
| 426 |
+
kwargs = {'Metadata': {'version': 2}, 'ACL': 'authenticated-read', 'StorageClass': 'STANDARD_IA'}
|
| 427 |
+
fout = open('s3://bucket/key', 'wb', transport_params={'client_kwargs': {'S3.Client.create_multipart_upload': kwargs}})
|
| 428 |
+
|
| 429 |
+
Iterating Over an S3 Bucket's Contents
--------------------------------------

Since going over all (or select) keys in an S3 bucket is a very common operation, there's also an extra function ``smart_open.s3.iter_bucket()`` that does this efficiently, **processing the bucket keys in parallel** (using multiprocessing):

.. code-block:: python

    >>> from smart_open import s3
    >>> # we use workers=1 for reproducibility; you should use as many workers as you have cores
    >>> bucket = 'silo-open-data'
    >>> prefix = 'Official/annual/monthly_rain/'
    >>> for key, content in s3.iter_bucket(bucket, prefix=prefix, accept_key=lambda key: '/201' in key, workers=1, key_limit=3):
    ...     print(key, round(len(content) / 2**20))
    Official/annual/monthly_rain/2010.monthly_rain.nc 13
    Official/annual/monthly_rain/2011.monthly_rain.nc 13
    Official/annual/monthly_rain/2012.monthly_rain.nc 13

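The filter-then-fan-out pattern behind ``iter_bucket`` can be sketched with the standard library. In this simplified sketch, threads stand in for the multiprocessing worker pool and ``fetch_key`` is a hypothetical stand-in for the actual S3 download:

```python
from concurrent.futures import ThreadPoolExecutor

def fetch_key(key):
    # Stand-in for downloading one object's content from S3.
    return key, ('payload for %s' % key).encode()

def iter_keys(keys, accept_key=lambda k: True, workers=4, key_limit=None):
    # Filter keys first, then download the survivors in parallel.
    selected = [k for k in keys if accept_key(k)]
    if key_limit is not None:
        selected = selected[:key_limit]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        yield from pool.map(fetch_key, selected)

results = list(iter_keys(['a/2010.nc', 'a/2011.nc', 'b/1999.nc'],
                         accept_key=lambda k: '/20' in k, key_limit=2))
```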
GCS Credentials
---------------
``smart_open`` uses the ``google-cloud-storage`` library to talk to GCS.
``google-cloud-storage`` uses the ``google-cloud`` package under the hood to handle authentication.
There are several `options <https://googleapis.dev/python/google-api-core/latest/auth.html>`__ for providing
credentials.
By default, ``smart_open`` will defer to ``google-cloud-storage`` and let it take care of the credentials.

To override this behavior, pass a ``google.cloud.storage.Client`` object as a transport parameter to the ``open`` function.
You can `customize the credentials <https://googleapis.dev/python/storage/latest/client.html>`__
when constructing the client. ``smart_open`` will then use the client when talking to GCS. To follow along with
the example below, `refer to Google's guide <https://cloud.google.com/storage/docs/reference/libraries#setting_up_authentication>`__
to setting up GCS authentication with a service account.

.. code-block:: python

    import os
    from google.cloud.storage import Client
    service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
    client = Client.from_service_account_json(service_account_path)
    fin = open('gs://gcp-public-data-landsat/index.csv.gz', transport_params=dict(client=client))

If you need more credential options, you can create an explicit ``google.auth.credentials.Credentials`` object
and pass it to the Client. To create an API token for use in the example below, refer to the
`GCS authentication guide <https://cloud.google.com/storage/docs/authentication#apiauth>`__.

.. code-block:: python

    import os
    from google.auth.credentials import Credentials
    from google.cloud.storage import Client
    token = os.environ['GOOGLE_API_TOKEN']
    credentials = Credentials(token=token)
    client = Client(credentials=credentials)
    fin = open('gs://gcp-public-data-landsat/index.csv.gz', transport_params={'client': client})

GCS Advanced Usage
------------------

Additional keyword arguments can be propagated to the GCS open method (`docs <https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.blob.Blob#google_cloud_storage_blob_Blob_open>`__), which is used by ``smart_open`` under the hood, using the ``blob_open_kwargs`` transport parameter.

Additional keyword arguments can be propagated to the GCS ``get_blob`` method (`docs <https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.bucket.Bucket#google_cloud_storage_bucket_Bucket_get_blob>`__) when in read mode, using the ``get_blob_kwargs`` transport parameter.

Additional blob properties (`docs <https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.blob.Blob#properties>`__) can be set before an upload, as long as they are not read-only, using the ``blob_properties`` transport parameter.

.. code-block:: python

    open_kwargs = {'predefined_acl': 'authenticated-read'}
    properties = {'metadata': {'version': 2}, 'storage_class': 'COLDLINE'}
    fout = open('gs://bucket/key', 'wb', transport_params={'blob_open_kwargs': open_kwargs, 'blob_properties': properties})

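How a properties dict can be applied to a blob object before an upload can be shown with a toy (``FakeBlob`` and ``apply_blob_properties`` are made up for illustration; this is not the library's code):

```python
class FakeBlob:
    # Minimal stand-in for the writable properties of google.cloud.storage.blob.Blob.
    def __init__(self):
        self.metadata = None
        self.storage_class = 'STANDARD'

def apply_blob_properties(blob, properties):
    # Assign each property named in the dict onto the blob before uploading.
    for name, value in properties.items():
        setattr(blob, name, value)
    return blob

blob = apply_blob_properties(FakeBlob(),
                             {'metadata': {'version': 2}, 'storage_class': 'COLDLINE'})
```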
Azure Credentials
-----------------

``smart_open`` uses the ``azure-storage-blob`` library to talk to Azure Blob Storage.
By default, ``smart_open`` will defer to ``azure-storage-blob`` and let it take care of the credentials.

Azure Blob Storage does not have any way of inferring credentials; therefore, passing an ``azure.storage.blob.BlobServiceClient``
object as a transport parameter to the ``open`` function is required.
You can `customize the credentials <https://docs.microsoft.com/en-us/azure/storage/common/storage-samples-python#authentication>`__
when constructing the client. ``smart_open`` will then use the client when talking to Azure Blob Storage. To follow along with
the example below, `refer to Azure's guide <https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python#copy-your-credentials-from-the-azure-portal>`__
to setting up authentication.

.. code-block:: python

    import os
    from azure.storage.blob import BlobServiceClient
    azure_storage_connection_string = os.environ['AZURE_STORAGE_CONNECTION_STRING']
    client = BlobServiceClient.from_connection_string(azure_storage_connection_string)
    fin = open('azure://my_container/my_blob.txt', transport_params={'client': client})

If you need more credential options, refer to the
`Azure Storage authentication guide <https://docs.microsoft.com/en-us/azure/storage/common/storage-samples-python#authentication>`__.

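An Azure storage connection string is just a semicolon-separated list of ``key=value`` pairs. A rough stdlib sketch of how it decomposes (illustrative only; not the Azure SDK's parser):

```python
def parse_connection_string(conn_str):
    # Split on ';' into key=value pairs; split each pair on the FIRST '='
    # only, since values (e.g. base64 account keys) may themselves contain '='.
    return dict(part.split('=', 1) for part in conn_str.split(';') if part)

cs = ('DefaultEndpointsProtocol=https;AccountName=myacct;'
      'AccountKey=abc123==;EndpointSuffix=core.windows.net')
settings = parse_connection_string(cs)
```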
Azure Advanced Usage
--------------------

Additional keyword arguments can be propagated to the ``commit_block_list`` method (`docs <https://azuresdkdocs.blob.core.windows.net/$web/python/azure-storage-blob/12.14.1/azure.storage.blob.html#azure.storage.blob.BlobClient.commit_block_list>`__), which is used by ``smart_open`` under the hood for uploads, using the ``blob_kwargs`` transport parameter.

.. code-block:: python

    kwargs = {'metadata': {'version': 2}}
    fout = open('azure://container/key', 'wb', transport_params={'blob_kwargs': kwargs})

Drop-in replacement of ``pathlib.Path.open``
--------------------------------------------

``smart_open.open`` can also be used with ``Path`` objects.
The built-in ``Path.open()`` is not able to read text from compressed files, so use ``patch_pathlib`` to replace it with ``smart_open.open()`` instead.
This can be helpful when, for example, working with compressed files.

.. code-block:: python

    >>> from pathlib import Path
    >>> from smart_open.smart_open_lib import patch_pathlib
    >>>
    >>> _ = patch_pathlib()  # replace `Path.open` with `smart_open.open`
    >>>
    >>> path = Path("smart_open/tests/test_data/crime-and-punishment.txt.gz")
    >>>
    >>> with path.open("r") as infile:
    ...     print(infile.readline()[:41])
    В начале июля, в чрезвычайно жаркое время

How do I ...?
=============

See `this document <howto.md>`__.

Extending ``smart_open``
========================

See `this document <extending.md>`__.

Testing ``smart_open``
======================

``smart_open`` comes with a comprehensive suite of unit tests.
Before you can run the test suite, install the test dependencies::

    pip install -e .[test]

Now, you can run the unit tests::

    pytest smart_open

The tests are also run automatically with `Travis CI <https://travis-ci.org/RaRe-Technologies/smart_open>`_ on every push and pull request.

Comments, bug reports
=====================

``smart_open`` lives on `GitHub <https://github.com/RaRe-Technologies/smart_open>`_. You can file
issues or pull requests there. Suggestions, pull requests and improvements welcome!

----------------

``smart_open`` is open source software released under the `MIT license <https://github.com/piskvorky/smart_open/blob/master/LICENSE>`_.
Copyright (c) 2015-now `Radim Řehůřek <https://radimrehurek.com>`_.

.venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/WHEEL
ADDED
@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.45.1)
Root-Is-Purelib: true
Tag: py3-none-any

.venv/lib/python3.11/site-packages/torchvision/__init__.py
ADDED
@@ -0,0 +1,105 @@
import os
import warnings
from modulefinder import Module

import torch

# Don't re-order these, we need to load the _C extension (done when importing
# .extensions) before entering _meta_registrations.
from .extension import _HAS_OPS  # usort:skip
from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip

try:
    from .version import __version__  # noqa: F401
except ImportError:
    pass


# Check if torchvision is being imported within the root folder
if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
    os.path.realpath(os.getcwd()), "torchvision"
):
    message = (
        "You are importing torchvision within its own root folder ({}). "
        "This is not expected to work and may give errors. Please exit the "
        "torchvision project source and relaunch your python interpreter."
    )
    warnings.warn(message.format(os.getcwd()))

_image_backend = "PIL"

_video_backend = "pyav"


def set_image_backend(backend):
    """
    Specifies the package used to load images.

    Args:
        backend (string): Name of the image backend. one of {'PIL', 'accimage'}.
            The :mod:`accimage` package uses the Intel IPP library. It is
            generally faster than PIL, but does not support as many operations.
    """
    global _image_backend
    if backend not in ["PIL", "accimage"]:
        raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
    _image_backend = backend


def get_image_backend():
    """
    Gets the name of the package used to load images
    """
    return _image_backend


def set_video_backend(backend):
    """
    Specifies the package used to decode videos.

    Args:
        backend (string): Name of the video backend. one of {'pyav', 'video_reader'}.
            The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
            binding for the FFmpeg libraries.
            The :mod:`video_reader` package includes a native C++ implementation on
            top of FFMPEG libraries, and a python API of TorchScript custom operator.
            It generally decodes faster than :mod:`pyav`, but is perhaps less robust.

    .. note::
        Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
        backend, please compile torchvision from source.
    """
    global _video_backend
    if backend not in ["pyav", "video_reader", "cuda"]:
        raise ValueError("Invalid video backend '%s'. Options are 'pyav', 'video_reader' and 'cuda'" % backend)
    if backend == "video_reader" and not io._HAS_CPU_VIDEO_DECODER:
        # TODO: better messages
        message = "video_reader video backend is not available. Please compile torchvision from source and try again"
        raise RuntimeError(message)
    elif backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
        # TODO: better messages
        message = "cuda video backend is not available."
        raise RuntimeError(message)
    else:
        _video_backend = backend


def get_video_backend():
    """
    Returns the currently active video backend used to decode videos.

    Returns:
        str: Name of the video backend. one of {'pyav', 'video_reader'}.
    """

    return _video_backend


def _is_tracing():
    return torch._C._get_tracing_state()


def disable_beta_transforms_warning():
    # Noop, only exists to avoid breaking existing code.
    # See https://github.com/pytorch/vision/issues/7896
    pass

.venv/lib/python3.11/site-packages/torchvision/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (4.53 kB)

.venv/lib/python3.11/site-packages/torchvision/__pycache__/_internally_replaced_utils.cpython-311.pyc
ADDED
Binary file (2.44 kB)

.venv/lib/python3.11/site-packages/torchvision/__pycache__/_meta_registrations.cpython-311.pyc
ADDED
Binary file (12.5 kB)

.venv/lib/python3.11/site-packages/torchvision/__pycache__/_utils.cpython-311.pyc
ADDED
Binary file (2.29 kB)

.venv/lib/python3.11/site-packages/torchvision/__pycache__/extension.cpython-311.pyc
ADDED
Binary file (4.05 kB)

.venv/lib/python3.11/site-packages/torchvision/__pycache__/utils.cpython-311.pyc
ADDED
Binary file (37.9 kB)

.venv/lib/python3.11/site-packages/torchvision/__pycache__/version.cpython-311.pyc
ADDED
Binary file (460 Bytes)

.venv/lib/python3.11/site-packages/torchvision/_internally_replaced_utils.py
ADDED
@@ -0,0 +1,50 @@
import importlib.machinery
import os

from torch.hub import _get_torch_home


_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
_USE_SHARDED_DATASETS = False


def _download_file_from_remote_location(fpath: str, url: str) -> None:
    pass


def _is_remote_location_available() -> bool:
    return False


try:
    from torch.hub import load_state_dict_from_url  # noqa: 401
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: 401


def _get_extension_path(lib_name):

    lib_dir = os.path.dirname(__file__)
    if os.name == "nt":
        # Register the main torchvision library location on the default DLL path
        import ctypes

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        if with_load_library_flags:
            kernel32.AddDllDirectory.restype = ctypes.c_void_p

        os.add_dll_directory(lib_dir)

        kernel32.SetErrorMode(prev_error_mode)

    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(lib_name)
    if ext_specs is None:
        raise ImportError

    return ext_specs.origin

.venv/lib/python3.11/site-packages/torchvision/_meta_registrations.py
ADDED
@@ -0,0 +1,225 @@
import functools

import torch
import torch._custom_ops
import torch.library

# Ensure that torch.ops.torchvision is visible
import torchvision.extension  # noqa: F401


@functools.lru_cache(None)
def get_meta_lib():
    return torch.library.Library("torchvision", "IMPL", "Meta")


def register_meta(op_name, overload_name="default"):
    def wrapper(fn):
        if torchvision.extension._has_ops():
            get_meta_lib().impl(getattr(getattr(torch.ops.torchvision, op_name), overload_name), fn)
        return fn

    return wrapper


@register_meta("roi_align")
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
    torch._check(
        input.dtype == rois.dtype,
        lambda: (
            "Expected tensor for input to have the same type as tensor for rois; "
            f"but type {input.dtype} does not equal {rois.dtype}"
        ),
    )
    num_rois = rois.size(0)
    channels = input.size(1)
    return input.new_empty((num_rois, channels, pooled_height, pooled_width))


@register_meta("_roi_align_backward")
def meta_roi_align_backward(
    grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
):
    torch._check(
        grad.dtype == rois.dtype,
        lambda: (
            "Expected tensor for grad to have the same type as tensor for rois; "
            f"but type {grad.dtype} does not equal {rois.dtype}"
        ),
    )
    return grad.new_empty((batch_size, channels, height, width))


@register_meta("ps_roi_align")
def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
    torch._check(
        input.dtype == rois.dtype,
        lambda: (
            "Expected tensor for input to have the same type as tensor for rois; "
            f"but type {input.dtype} does not equal {rois.dtype}"
        ),
    )
    channels = input.size(1)
    torch._check(
        channels % (pooled_height * pooled_width) == 0,
        "input channels must be a multiple of pooling height * pooling width",
    )

    num_rois = rois.size(0)
    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
    return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")


@register_meta("_ps_roi_align_backward")
def meta_ps_roi_align_backward(
    grad,
    rois,
    channel_mapping,
    spatial_scale,
    pooled_height,
    pooled_width,
    sampling_ratio,
    batch_size,
    channels,
    height,
    width,
):
    torch._check(
        grad.dtype == rois.dtype,
        lambda: (
            "Expected tensor for grad to have the same type as tensor for rois; "
            f"but type {grad.dtype} does not equal {rois.dtype}"
        ),
    )
    return grad.new_empty((batch_size, channels, height, width))


@register_meta("roi_pool")
def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
    torch._check(
        input.dtype == rois.dtype,
        lambda: (
            "Expected tensor for input to have the same type as tensor for rois; "
            f"but type {input.dtype} does not equal {rois.dtype}"
        ),
    )
    num_rois = rois.size(0)
    channels = input.size(1)
    out_size = (num_rois, channels, pooled_height, pooled_width)
    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)


@register_meta("_roi_pool_backward")
def meta_roi_pool_backward(
    grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
):
    torch._check(
        grad.dtype == rois.dtype,
        lambda: (
            "Expected tensor for grad to have the same type as tensor for rois; "
            f"but type {grad.dtype} does not equal {rois.dtype}"
        ),
    )
    return grad.new_empty((batch_size, channels, height, width))


@register_meta("ps_roi_pool")
def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
    torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
    torch._check(
        input.dtype == rois.dtype,
        lambda: (
            "Expected tensor for input to have the same type as tensor for rois; "
            f"but type {input.dtype} does not equal {rois.dtype}"
        ),
    )
    channels = input.size(1)
    torch._check(
        channels % (pooled_height * pooled_width) == 0,
        "input channels must be a multiple of pooling height * pooling width",
    )
    num_rois = rois.size(0)
    out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
    return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)


@register_meta("_ps_roi_pool_backward")
def meta_ps_roi_pool_backward(
    grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
):
    torch._check(
        grad.dtype == rois.dtype,
        lambda: (
            "Expected tensor for grad to have the same type as tensor for rois; "
            f"but type {grad.dtype} does not equal {rois.dtype}"
        ),
    )
    return grad.new_empty((batch_size, channels, height, width))


@torch.library.register_fake("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
    torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
    torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
    torch._check(
        dets.size(0) == scores.size(0),
        lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
    )
    ctx = torch._custom_ops.get_ctx()
    num_to_keep = ctx.create_unbacked_symint()
    return dets.new_empty(num_to_keep, dtype=torch.long)


@register_meta("deform_conv2d")
def meta_deform_conv2d(
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dil_h,
    dil_w,
    n_weight_grps,
    n_offset_grps,
    use_mask,
):

    out_height, out_width = offset.shape[-2:]
    out_channels = weight.shape[0]
    batch_size = input.shape[0]
    return input.new_empty((batch_size, out_channels, out_height, out_width))


@register_meta("_deform_conv2d_backward")
def meta_deform_conv2d_backward(
    grad,
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation_h,
    dilation_w,
    groups,
    offset_groups,
    use_mask,
):

    grad_input = input.new_empty(input.shape)
    grad_weight = weight.new_empty(weight.shape)
    grad_offset = offset.new_empty(offset.shape)
    grad_mask = mask.new_empty(mask.shape)
    grad_bias = bias.new_empty(bias.shape)
    return grad_input, grad_weight, grad_offset, grad_mask, grad_bias

.venv/lib/python3.11/site-packages/torchvision/_utils.py
ADDED
@@ -0,0 +1,32 @@
import enum
from typing import Sequence, Type, TypeVar

T = TypeVar("T", bound=enum.Enum)


class StrEnumMeta(enum.EnumMeta):
    auto = enum.auto

    def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
        try:
            return self[member]
        except KeyError:
            # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
            # soon as it is migrated.
            raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None


class StrEnum(enum.Enum, metaclass=StrEnumMeta):
    pass


def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    if not seq:
        return ""
    if len(seq) == 1:
        return f"'{seq[0]}'"

    head = "'" + "', '".join([str(item) for item in seq[:-1]]) + "'"
    tail = f"{'' if separate_last and len(seq) == 2 else ','} {separate_last}'{seq[-1]}'"

    return head + tail

.venv/lib/python3.11/site-packages/torchvision/extension.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from ._internally_replaced_utils import _get_extension_path

_HAS_OPS = False


def _has_ops():
    return False


try:
    # On Windows Python-3.8.x has `os.add_dll_directory` call,
    # which is called to configure dll search path.
    # To find cuda related dlls we need to make sure the
    # conda environment/bin path is configured. Please take a look:
    # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
    # Please note: if some path can't be added using add_dll_directory we simply ignore this path
    if os.name == "nt" and sys.version_info < (3, 9):
        env_path = os.environ["PATH"]
        path_arr = env_path.split(";")
        for path in path_arr:
            if os.path.exists(path):
                try:
                    os.add_dll_directory(path)  # type: ignore[attr-defined]
                except Exception:
                    pass

    lib_path = _get_extension_path("_C")
    torch.ops.load_library(lib_path)
    _HAS_OPS = True

    def _has_ops():  # noqa: F811
        return True

except (ImportError, OSError):
    pass


def _assert_has_ops():
    if not _has_ops():
        raise RuntimeError(
            "Couldn't load custom C++ ops. This can happen if your PyTorch and "
            "torchvision versions are incompatible, or if you had errors while compiling "
            "torchvision from source. For further information on the compatible versions, check "
            "https://github.com/pytorch/vision#installation for the compatibility matrix. "
            "Please check your PyTorch version with torch.__version__ and your torchvision "
            "version with torchvision.__version__ and verify if they are compatible, and if not "
            "please reinstall torchvision so that it matches your PyTorch install."
        )


def _check_cuda_version():
    """
    Make sure that CUDA versions match between the pytorch install and torchvision install
    """
    if not _HAS_OPS:
        return -1
    from torch.version import cuda as torch_version_cuda

    _version = torch.ops.torchvision._cuda_version()
    if _version != -1 and torch_version_cuda is not None:
        tv_version = str(_version)
        if int(tv_version) < 10000:
            tv_major = int(tv_version[0])
            tv_minor = int(tv_version[2])
        else:
            tv_major = int(tv_version[0:2])
            tv_minor = int(tv_version[3])
        t_version = torch_version_cuda.split(".")
        t_major = int(t_version[0])
        t_minor = int(t_version[1])
        if t_major != tv_major:
            raise RuntimeError(
                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
                f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
                f"CUDA Version={tv_major}.{tv_minor}. "
                "Please reinstall the torchvision that matches your PyTorch install."
            )
    return _version


def _load_library(lib_name):
    lib_path = _get_extension_path(lib_name)
    torch.ops.load_library(lib_path)


_check_cuda_version()
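The digit slicing inside `_check_cuda_version` above decodes the packed integer returned by `torch.ops.torchvision._cuda_version()` (e.g. `9020` encodes CUDA 9.2, `11080` encodes 11.8). The helper below is an illustrative pure-Python sketch of that decoding — the function name `parse_cuda_version` is mine, not a torchvision API — so the logic can be checked without loading the C++ extension:

```python
def parse_cuda_version(version: int) -> tuple:
    """Split the packed CUDA version integer into (major, minor).

    Mirrors the slicing in _check_cuda_version: values below 10000 are
    "M0m0" (9020 -> 9.2); larger values are "MMm0" (11080 -> 11.8).
    """
    s = str(version)
    if version < 10000:
        return int(s[0]), int(s[2])
    return int(s[0:2]), int(s[3])


print(parse_cuda_version(9020))   # -> (9, 2)
print(parse_cuda_version(11080))  # -> (11, 8)
```

Only the major component is compared against `torch.version.cuda`; a minor-version mismatch does not raise.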
.venv/lib/python3.11/site-packages/torchvision/models/alexnet.py
ADDED
@@ -0,0 +1,119 @@
from functools import partial
from typing import Any, Optional

import torch
import torch.nn as nn

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]


class AlexNet(nn.Module):
    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


class AlexNet_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 61100840,
            "min_size": (63, 63),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 56.522,
                    "acc@5": 79.066,
                }
            },
            "_ops": 0.714,
            "_file_size": 233.087,
            "_docs": """
                These weights reproduce closely the results of the paper using a simplified training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.

    .. note::
        AlexNet was originally introduced in the `ImageNet Classification with
        Deep Convolutional Neural Networks
        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
        paper. Our implementation is based instead on the "One weird trick"
        paper above.

    Args:
        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.AlexNet_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.alexnet.AlexNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/alexnet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.AlexNet_Weights
        :members:
    """

    weights = AlexNet_Weights.verify(weights)

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = AlexNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
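A quick sanity check on the architecture above: the classifier's first `nn.Linear(256 * 6 * 6, 4096)` matches the spatial size that `features` produces for the canonical 224×224 input (and `AdaptiveAvgPool2d((6, 6))` guarantees 6×6 for other sizes). The sketch below traces the conv/pool shape arithmetic in pure Python; `conv_out` is an illustrative helper of mine, not a torchvision function:

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Output spatial size of a conv/pool layer: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# Trace AlexNet.features for a 224x224 input
s = 224
s = conv_out(s, 11, 4, 2)  # Conv2d(3, 64, 11, stride=4, pad=2)  -> 55
s = conv_out(s, 3, 2)      # MaxPool2d(3, stride=2)              -> 27
s = conv_out(s, 5, 1, 2)   # Conv2d(64, 192, 5, pad=2)           -> 27
s = conv_out(s, 3, 2)      # MaxPool2d(3, stride=2)              -> 13
s = conv_out(s, 3, 1, 1)   # three 3x3 convs with pad=1 keep 13
s = conv_out(s, 3, 1, 1)
s = conv_out(s, 3, 1, 1)
s = conv_out(s, 3, 2)      # MaxPool2d(3, stride=2)              -> 6
print(s)  # 6, so the flattened classifier input is 256 * 6 * 6
```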
.venv/lib/python3.11/site-packages/torchvision/models/convnext.py
ADDED
@@ -0,0 +1,414 @@
from functools import partial
from typing import Any, Callable, List, Optional, Sequence

import torch
from torch import nn, Tensor
from torch.nn import functional as F

from ..ops.misc import Conv2dNormActivation, Permute
from ..ops.stochastic_depth import StochasticDepth
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "ConvNeXt",
    "ConvNeXt_Tiny_Weights",
    "ConvNeXt_Small_Weights",
    "ConvNeXt_Base_Weights",
    "ConvNeXt_Large_Weights",
    "convnext_tiny",
    "convnext_small",
    "convnext_base",
    "convnext_large",
]


class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x


class CNBlock(nn.Module):
    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            Permute([0, 2, 3, 1]),
            norm_layer(dim),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        return result


class CNBlockConfig:
    # Stores information listed at Section 3 of the ConvNeXt paper
    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels
        self.num_layers = num_layers

    def __repr__(self) -> str:
        s = self.__class__.__name__ + "("
        s += "input_channels={input_channels}"
        s += ", out_channels={out_channels}"
        s += ", num_layers={num_layers}"
        s += ")"
        return s.format(**self.__dict__)


class ConvNeXt(nn.Module):
    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    "_docs": """
        These weights improve upon the results of the original paper by using a modified version of TorchVision's
        `new training recipe
        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
    """,
}


class ConvNeXt_Tiny_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.520,
                    "acc@5": 96.146,
                }
            },
            "_ops": 4.456,
            "_file_size": 109.119,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ConvNeXt_Small_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.616,
                    "acc@5": 96.650,
                }
            },
            "_ops": 8.684,
            "_file_size": 191.703,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ConvNeXt_Base_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.062,
                    "acc@5": 96.870,
                }
            },
            "_ops": 15.355,
            "_file_size": 338.064,
        },
    )
    DEFAULT = IMAGENET1K_V1


class ConvNeXt_Large_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.414,
                    "acc@5": 96.976,
                }
            },
            "_ops": 34.361,
            "_file_size": 754.537,
        },
    )
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Tiny model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
        :members:
    """
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Small model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
        :members:
    """
    weights = ConvNeXt_Small_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Base model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
        :members:
    """
    weights = ConvNeXt_Base_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Large model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`
            below for more details and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.convnext.ConvNeXt``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/convnext.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
        :members:
    """
    weights = ConvNeXt_Large_Weights.verify(weights)

    block_setting = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(block_setting, stochastic_depth_prob, weights, progress, **kwargs)
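The `ConvNeXt.__init__` loop above linearly ramps each block's stochastic-depth drop probability from 0 up to `stochastic_depth_prob` across all stages. The sketch below reproduces that schedule in plain Python (the function name `stochastic_depth_schedule` is illustrative, not a torchvision API), so the per-block probabilities can be inspected without building the model:

```python
def stochastic_depth_schedule(stochastic_depth_prob: float, num_layers_per_stage: list) -> list:
    """Per-block drop probabilities, ramped linearly over all blocks,
    mirroring sd_prob = p * stage_block_id / (total_stage_blocks - 1.0)."""
    total = sum(num_layers_per_stage)
    probs = []
    block_id = 0
    for n in num_layers_per_stage:
        for _ in range(n):
            probs.append(stochastic_depth_prob * block_id / (total - 1.0))
            block_id += 1
    return probs


# convnext_tiny: stages of 3, 3, 9, 3 blocks with peak probability 0.1
probs = stochastic_depth_schedule(0.1, [3, 3, 9, 3])
print(probs[0], probs[-1])  # 0.0 for the first block, 0.1 for the last
```

The first block is never dropped and only the final block reaches the nominal `stochastic_depth_prob`, which is why the deeper variants (small/base/large) pass larger peak values (0.4–0.5).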
.venv/lib/python3.11/site-packages/torchvision/models/densenet.py
ADDED
@@ -0,0 +1,448 @@
| 1 |
+
import re
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import torch.utils.checkpoint as cp
|
| 10 |
+
from torch import Tensor
|
| 11 |
+
|
| 12 |
+
from ..transforms._presets import ImageClassification
|
| 13 |
+
from ..utils import _log_api_usage_once
|
| 14 |
+
from ._api import register_model, Weights, WeightsEnum
|
| 15 |
+
from ._meta import _IMAGENET_CATEGORIES
|
| 16 |
+
from ._utils import _ovewrite_named_param, handle_legacy_interface
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"DenseNet",
|
| 20 |
+
"DenseNet121_Weights",
|
| 21 |
+
"DenseNet161_Weights",
|
| 22 |
+
"DenseNet169_Weights",
|
| 23 |
+
"DenseNet201_Weights",
|
| 24 |
+
"densenet121",
|
| 25 |
+
"densenet161",
|
| 26 |
+
"densenet169",
|
| 27 |
+
"densenet201",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class _DenseLayer(nn.Module):
    def __init__(
        self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
    ) -> None:
        super().__init__()
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)

        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)

        self.drop_rate = float(drop_rate)
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))  # noqa: T484
        return bottleneck_output

    # todo: rewrite when torchscript supports any
    def any_requires_grad(self, input: List[Tensor]) -> bool:
        for tensor in input:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused  # noqa: T484
    def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
        def closure(*inputs):
            return self.bn_function(inputs)

        return cp.checkpoint(closure, *input, use_reentrant=False)

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: List[Tensor]) -> Tensor:  # noqa: F811
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        pass

    # torchscript does not yet support *args, so we overload method
    # allowing it to take either a List[Tensor] or single Tensor
    def forward(self, input: Tensor) -> Tensor:  # noqa: F811
        if isinstance(input, Tensor):
            prev_features = [input]
        else:
            prev_features = input

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("Memory Efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features

class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(
        self,
        num_layers: int,
        num_input_features: int,
        bn_size: int,
        growth_rate: int,
        drop_rate: float,
        memory_efficient: bool = False,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)

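As an aside (not part of the library file): the dense connectivity above means each `_DenseLayer` receives everything produced so far and contributes exactly `growth_rate` new feature maps, so the block's concatenated output has `num_input_features + num_layers * growth_rate` channels. A minimal sketch of that bookkeeping, no torch required:

```python
# Channel bookkeeping for one _DenseBlock: the features list starts with the
# block's input channels, and each _DenseLayer appends growth_rate channels.
def dense_block_channels(num_input_features: int, num_layers: int, growth_rate: int) -> list:
    channels = [num_input_features]
    for _ in range(num_layers):
        # layer i sees sum(channels[: i + 1]) input channels, emits growth_rate
        channels.append(growth_rate)
    return channels

# First block of densenet121: 64 input features, 6 layers, growth_rate 32.
chans = dense_block_channels(64, 6, 32)
print(sum(chans))  # 256 channels after torch.cat(features, 1)
```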
class _Transition(nn.Sequential):
    def __init__(self, num_input_features: int, num_output_features: int) -> None:
        super().__init__()
        self.norm = nn.BatchNorm2d(num_input_features)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
            but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
    """

    def __init__(
        self,
        growth_rate: int = 32,
        block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
        num_init_features: int = 64,
        bn_size: int = 4,
        drop_rate: float = 0,
        num_classes: int = 1000,
        memory_efficient: bool = False,
    ) -> None:

        super().__init__()
        _log_api_usage_once(self)

        # First convolution
        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", nn.BatchNorm2d(num_init_features)),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)


def _densenet(
    growth_rate: int,
    block_config: Tuple[int, int, int, int],
    num_init_features: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> DenseNet:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)

    if weights is not None:
        _load_state_dict(model=model, weights=weights, progress=progress)

    return model

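As an aside (not part of the library file): the key-renaming regex in `_load_state_dict` collapses the legacy dotted suffix (`norm.1`) into the current name (`norm1`) by concatenating the two capture groups. A standalone check of that rewrite on a representative checkpoint key:

```python
import re

# Same pattern used by _load_state_dict above.
pattern = re.compile(
    r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)

old_key = "features.denseblock1.denselayer1.norm.1.weight"
m = pattern.match(old_key)
# group(1) ends at '...norm', group(2) starts with '1.weight';
# concatenating drops the '.' between 'norm' and '1'.
new_key = m.group(1) + m.group(2)
print(new_key)  # features.denseblock1.denselayer1.norm1.weight
```

Keys outside dense layers (e.g. `features.conv0.weight`) don't match and are left untouched.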
_COMMON_META = {
    "min_size": (29, 29),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/pull/116",
    "_docs": """These weights are ported from LuaTorch.""",
}


class DenseNet121_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 7978856,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.434,
                    "acc@5": 91.972,
                }
            },
            "_ops": 2.834,
            "_file_size": 30.845,
        },
    )
    DEFAULT = IMAGENET1K_V1


class DenseNet161_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 28681000,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.138,
                    "acc@5": 93.560,
                }
            },
            "_ops": 7.728,
            "_file_size": 110.369,
        },
    )
    DEFAULT = IMAGENET1K_V1


class DenseNet169_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 14149480,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.600,
                    "acc@5": 92.806,
                }
            },
            "_ops": 3.36,
            "_file_size": 54.708,
        },
    )
    DEFAULT = IMAGENET1K_V1


class DenseNet201_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/densenet201-c1103571.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 20013928,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 76.896,
                    "acc@5": 93.370,
                }
            },
            "_ops": 4.291,
            "_file_size": 77.373,
        },
    )
    DEFAULT = IMAGENET1K_V1

@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-121 model from
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet121_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.DenseNet121_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.DenseNet121_Weights
        :members:
    """
    weights = DenseNet121_Weights.verify(weights)

    return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-161 model from
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet161_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.DenseNet161_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.DenseNet161_Weights
        :members:
    """
    weights = DenseNet161_Weights.verify(weights)

    return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-169 model from
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet169_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.DenseNet169_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.DenseNet169_Weights
        :members:
    """
    weights = DenseNet169_Weights.verify(weights)

    return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
    r"""Densenet-201 model from
    `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.

    Args:
        weights (:class:`~torchvision.models.DenseNet201_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.DenseNet201_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.DenseNet201_Weights
        :members:
    """
    weights = DenseNet201_Weights.verify(weights)

    return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
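As an aside (not part of the library file): the four builders differ only in `(growth_rate, block_config, num_init_features)`, and the classifier's `in_features` follows directly from the block/transition loop in `DenseNet.__init__`. A standalone sketch of that arithmetic:

```python
# Mirrors the feature-count loop in DenseNet.__init__: each dense block adds
# num_layers * growth_rate channels, and every _Transition (all blocks except
# the last) halves the channel count.
def final_num_features(growth_rate, block_config, num_init_features):
    num_features = num_init_features
    for i, num_layers in enumerate(block_config):
        num_features += num_layers * growth_rate
        if i != len(block_config) - 1:
            num_features //= 2
    return num_features

print(final_num_features(32, (6, 12, 24, 16), 64))  # 1024 -> densenet121 classifier in_features
print(final_num_features(48, (6, 12, 36, 24), 96))  # 2208 -> densenet161
```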
.venv/lib/python3.11/site-packages/torchvision/models/efficientnet.py
ADDED
@@ -0,0 +1,1131 @@
import copy
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import nn, Tensor
from torchvision.ops import StochasticDepth

from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
from ..transforms._presets import ImageClassification, InterpolationMode
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "EfficientNet",
    "EfficientNet_B0_Weights",
    "EfficientNet_B1_Weights",
    "EfficientNet_B2_Weights",
    "EfficientNet_B3_Weights",
    "EfficientNet_B4_Weights",
    "EfficientNet_B5_Weights",
    "EfficientNet_B6_Weights",
    "EfficientNet_B7_Weights",
    "EfficientNet_V2_S_Weights",
    "EfficientNet_V2_M_Weights",
    "EfficientNet_V2_L_Weights",
    "efficientnet_b0",
    "efficientnet_b1",
    "efficientnet_b2",
    "efficientnet_b3",
    "efficientnet_b4",
    "efficientnet_b5",
    "efficientnet_b6",
    "efficientnet_b7",
    "efficientnet_v2_s",
    "efficientnet_v2_m",
    "efficientnet_v2_l",
]

@dataclass
class _MBConvConfig:
    expand_ratio: float
    kernel: int
    stride: int
    input_channels: int
    out_channels: int
    num_layers: int
    block: Callable[..., nn.Module]

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
        return _make_divisible(channels * width_mult, 8, min_value)


class MBConvConfig(_MBConvConfig):
    # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper
    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        width_mult: float = 1.0,
        depth_mult: float = 1.0,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        input_channels = self.adjust_channels(input_channels, width_mult)
        out_channels = self.adjust_channels(out_channels, width_mult)
        num_layers = self.adjust_depth(num_layers, depth_mult)
        if block is None:
            block = MBConv
        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)

    @staticmethod
    def adjust_depth(num_layers: int, depth_mult: float):
        return int(math.ceil(num_layers * depth_mult))

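As an aside (not part of the library file): `adjust_channels` delegates to `_make_divisible`, which is imported from `._utils` above and not shown in this file. A sketch of the MobileNet-style rounding rule that helper follows, to make the width scaling concrete (the real implementation lives in `torchvision.models._utils`):

```python
# Sketch of the _make_divisible rounding rule: snap a scaled channel count to
# the nearest multiple of `divisor`, but never drop more than 10% below the
# requested value.
def make_divisible(v: float, divisor: int, min_value=None) -> int:
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # rounding down lost >10%: bump up one multiple
        new_v += divisor
    return new_v

print(make_divisible(32 * 1.0, 8))  # 32 (width_mult=1.0 leaves channels alone)
print(make_divisible(32 * 1.2, 8))  # 40 (38.4 snaps up to the next multiple of 8)
```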
class FusedMBConvConfig(_MBConvConfig):
    # Stores information listed at Table 4 of the EfficientNetV2 paper
    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        if block is None:
            block = FusedMBConv
        super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)

class MBConv(nn.Module):
    def __init__(
        self,
        cnf: MBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = SqueezeExcitation,
    ) -> None:
        super().__init__()

        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.SiLU

        # expand
        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        if expanded_channels != cnf.input_channels:
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        # depthwise
        layers.append(
            Conv2dNormActivation(
                expanded_channels,
                expanded_channels,
                kernel_size=cnf.kernel,
                stride=cnf.stride,
                groups=expanded_channels,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
            )
        )

        # squeeze and excitation
        squeeze_channels = max(1, cnf.input_channels // 4)
        layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))

        # project
        layers.append(
            Conv2dNormActivation(
                expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*layers)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result = self.stochastic_depth(result)
            result += input
        return result

class FusedMBConv(nn.Module):
    def __init__(
        self,
        cnf: FusedMBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__()

        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.SiLU

        expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        if expanded_channels != cnf.input_channels:
            # fused expand
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    expanded_channels,
                    kernel_size=cnf.kernel,
                    stride=cnf.stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

            # project
            layers.append(
                Conv2dNormActivation(
                    expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
                )
            )
        else:
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    cnf.out_channels,
                    kernel_size=cnf.kernel,
                    stride=cnf.stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        self.block = nn.Sequential(*layers)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result = self.stochastic_depth(result)
            result += input
        return result

class EfficientNet(nn.Module):
|
| 233 |
+
def __init__(
|
| 234 |
+
self,
|
| 235 |
+
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
|
| 236 |
+
dropout: float,
|
| 237 |
+
stochastic_depth_prob: float = 0.2,
|
| 238 |
+
num_classes: int = 1000,
|
| 239 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
| 240 |
+
last_channel: Optional[int] = None,
|
| 241 |
+
) -> None:
|
| 242 |
+
"""
|
| 243 |
+
EfficientNet V1 and V2 main class
|
| 244 |
+
|
| 245 |
+
Args:
|
| 246 |
+
inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
|
| 247 |
+
dropout (float): The droupout probability
|
| 248 |
+
stochastic_depth_prob (float): The stochastic depth probability
|
| 249 |
+
num_classes (int): Number of classes
|
| 250 |
+
norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
|
| 251 |
+
last_channel (int): The number of channels on the penultimate layer
|
| 252 |
+
"""
|
| 253 |
+
super().__init__()
|
| 254 |
+
_log_api_usage_once(self)
|
| 255 |
+
|
| 256 |
+
if not inverted_residual_setting:
|
| 257 |
+
raise ValueError("The inverted_residual_setting should not be empty")
|
| 258 |
+
elif not (
|
| 259 |
+
isinstance(inverted_residual_setting, Sequence)
|
| 260 |
+
and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
|
| 261 |
+
):
|
| 262 |
+
raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
|
| 263 |
+
|
| 264 |
+
if norm_layer is None:
|
| 265 |
+
norm_layer = nn.BatchNorm2d
|
| 266 |
+
|
| 267 |
+
layers: List[nn.Module] = []
|
| 268 |
+
|
| 269 |
+
# building first layer
|
| 270 |
+
firstconv_output_channels = inverted_residual_setting[0].input_channels
|
| 271 |
+
layers.append(
|
| 272 |
+
Conv2dNormActivation(
|
| 273 |
+
3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
|
| 274 |
+
)
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
# building inverted residual blocks
|
| 278 |
+
total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
|
| 279 |
+
stage_block_id = 0
|
| 280 |
+
for cnf in inverted_residual_setting:
|
| 281 |
+
stage: List[nn.Module] = []
|
| 282 |
+
for _ in range(cnf.num_layers):
|
| 283 |
+
# copy to avoid modifications. shallow copy is enough
|
| 284 |
+
block_cnf = copy.copy(cnf)
|
| 285 |
+
|
| 286 |
+
# overwrite info if not the first conv in the stage
|
| 287 |
+
if stage:
|
| 288 |
+
block_cnf.input_channels = block_cnf.out_channels
|
| 289 |
+
block_cnf.stride = 1
|
| 290 |
+
|
| 291 |
+
# adjust stochastic depth probability based on the depth of the stage block
|
| 292 |
+
sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks
|
| 293 |
+
|
| 294 |
+
stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
|
| 295 |
+
stage_block_id += 1
|
| 296 |
+
|
| 297 |
+
layers.append(nn.Sequential(*stage))
|
| 298 |
+
|
| 299 |
+
# building last several layers
|
| 300 |
+
lastconv_input_channels = inverted_residual_setting[-1].out_channels
|
| 301 |
+
lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
|
| 302 |
+
layers.append(
|
| 303 |
+
Conv2dNormActivation(
|
| 304 |
+
lastconv_input_channels,
|
| 305 |
+
lastconv_output_channels,
|
| 306 |
+
kernel_size=1,
|
| 307 |
+
norm_layer=norm_layer,
|
| 308 |
+
activation_layer=nn.SiLU,
|
| 309 |
+
)
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
self.features = nn.Sequential(*layers)
|
| 313 |
+
self.avgpool = nn.AdaptiveAvgPool2d(1)
|
| 314 |
+
self.classifier = nn.Sequential(
|
| 315 |
+
nn.Dropout(p=dropout, inplace=True),
|
| 316 |
+
nn.Linear(lastconv_output_channels, num_classes),
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
for m in self.modules():
|
| 320 |
+
if isinstance(m, nn.Conv2d):
|
| 321 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
| 322 |
+
if m.bias is not None:
|
| 323 |
+
nn.init.zeros_(m.bias)
|
| 324 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 325 |
+
nn.init.ones_(m.weight)
|
| 326 |
+
nn.init.zeros_(m.bias)
|
| 327 |
+
elif isinstance(m, nn.Linear):
|
| 328 |
+
init_range = 1.0 / math.sqrt(m.out_features)
|
| 329 |
+
nn.init.uniform_(m.weight, -init_range, init_range)
|
| 330 |
+
nn.init.zeros_(m.bias)
|
| 331 |
+
|
| 332 |
+
def _forward_impl(self, x: Tensor) -> Tensor:
|
| 333 |
+
x = self.features(x)
|
| 334 |
+
|
| 335 |
+
x = self.avgpool(x)
|
| 336 |
+
x = torch.flatten(x, 1)
|
| 337 |
+
|
| 338 |
+
x = self.classifier(x)
|
| 339 |
+
|
| 340 |
+
return x
|
| 341 |
+
|
| 342 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 343 |
+
return self._forward_impl(x)
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def _efficientnet(
|
| 347 |
+
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
|
| 348 |
+
dropout: float,
|
| 349 |
+
last_channel: Optional[int],
|
| 350 |
+
weights: Optional[WeightsEnum],
|
| 351 |
+
progress: bool,
|
| 352 |
+
**kwargs: Any,
|
| 353 |
+
) -> EfficientNet:
|
| 354 |
+
if weights is not None:
|
| 355 |
+
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
|
| 356 |
+
|
| 357 |
+
model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
|
| 358 |
+
|
| 359 |
+
if weights is not None:
|
| 360 |
+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
|
| 361 |
+
|
| 362 |
+
return model
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def _efficientnet_conf(
|
| 366 |
+
arch: str,
|
| 367 |
+
**kwargs: Any,
|
| 368 |
+
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
|
| 369 |
+
inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
|
| 370 |
+
if arch.startswith("efficientnet_b"):
|
| 371 |
+
bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
|
| 372 |
+
inverted_residual_setting = [
|
| 373 |
+
bneck_conf(1, 3, 1, 32, 16, 1),
|
| 374 |
+
bneck_conf(6, 3, 2, 16, 24, 2),
|
| 375 |
+
bneck_conf(6, 5, 2, 24, 40, 2),
|
| 376 |
+
bneck_conf(6, 3, 2, 40, 80, 3),
|
| 377 |
+
bneck_conf(6, 5, 1, 80, 112, 3),
|
| 378 |
+
bneck_conf(6, 5, 2, 112, 192, 4),
|
| 379 |
+
bneck_conf(6, 3, 1, 192, 320, 1),
|
| 380 |
+
]
|
| 381 |
+
last_channel = None
|
| 382 |
+
elif arch.startswith("efficientnet_v2_s"):
|
| 383 |
+
inverted_residual_setting = [
|
| 384 |
+
FusedMBConvConfig(1, 3, 1, 24, 24, 2),
|
| 385 |
+
FusedMBConvConfig(4, 3, 2, 24, 48, 4),
|
| 386 |
+
FusedMBConvConfig(4, 3, 2, 48, 64, 4),
|
| 387 |
+
MBConvConfig(4, 3, 2, 64, 128, 6),
|
| 388 |
+
MBConvConfig(6, 3, 1, 128, 160, 9),
|
| 389 |
+
MBConvConfig(6, 3, 2, 160, 256, 15),
|
| 390 |
+
]
|
| 391 |
+
last_channel = 1280
|
| 392 |
+
elif arch.startswith("efficientnet_v2_m"):
|
| 393 |
+
inverted_residual_setting = [
|
| 394 |
+
FusedMBConvConfig(1, 3, 1, 24, 24, 3),
|
| 395 |
+
FusedMBConvConfig(4, 3, 2, 24, 48, 5),
|
| 396 |
+
FusedMBConvConfig(4, 3, 2, 48, 80, 5),
|
| 397 |
+
MBConvConfig(4, 3, 2, 80, 160, 7),
|
| 398 |
+
MBConvConfig(6, 3, 1, 160, 176, 14),
|
| 399 |
+
MBConvConfig(6, 3, 2, 176, 304, 18),
|
| 400 |
+
MBConvConfig(6, 3, 1, 304, 512, 5),
|
| 401 |
+
]
|
| 402 |
+
last_channel = 1280
|
| 403 |
+
elif arch.startswith("efficientnet_v2_l"):
|
| 404 |
+
inverted_residual_setting = [
|
| 405 |
+
FusedMBConvConfig(1, 3, 1, 32, 32, 4),
|
| 406 |
+
FusedMBConvConfig(4, 3, 2, 32, 64, 7),
|
| 407 |
+
FusedMBConvConfig(4, 3, 2, 64, 96, 7),
|
| 408 |
+
MBConvConfig(4, 3, 2, 96, 192, 10),
|
| 409 |
+
MBConvConfig(6, 3, 1, 192, 224, 19),
|
| 410 |
+
MBConvConfig(6, 3, 2, 224, 384, 25),
|
| 411 |
+
MBConvConfig(6, 3, 1, 384, 640, 7),
|
| 412 |
+
]
|
| 413 |
+
last_channel = 1280
|
| 414 |
+
else:
|
| 415 |
+
raise ValueError(f"Unsupported model type {arch}")
|
| 416 |
+
|
| 417 |
+
return inverted_residual_setting, last_channel
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
_COMMON_META: Dict[str, Any] = {
|
| 421 |
+
"categories": _IMAGENET_CATEGORIES,
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
_COMMON_META_V1 = {
|
| 426 |
+
**_COMMON_META,
|
| 427 |
+
"min_size": (1, 1),
|
| 428 |
+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
|
| 429 |
+
}
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
_COMMON_META_V2 = {
|
| 433 |
+
**_COMMON_META,
|
| 434 |
+
"min_size": (33, 33),
|
| 435 |
+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
|
| 436 |
+
}
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
class EfficientNet_B0_Weights(WeightsEnum):
|
| 440 |
+
IMAGENET1K_V1 = Weights(
|
| 441 |
+
# Weights ported from https://github.com/rwightman/pytorch-image-models/
|
| 442 |
+
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth",
|
| 443 |
+
transforms=partial(
|
| 444 |
+
ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
|
| 445 |
+
),
|
| 446 |
+
meta={
|
| 447 |
+
**_COMMON_META_V1,
|
| 448 |
+
"num_params": 5288548,
|
| 449 |
+
"_metrics": {
|
| 450 |
+
"ImageNet-1K": {
|
| 451 |
+
"acc@1": 77.692,
|
| 452 |
+
"acc@5": 93.532,
|
| 453 |
+
}
|
| 454 |
+
},
|
| 455 |
+
"_ops": 0.386,
|
| 456 |
+
"_file_size": 20.451,
|
| 457 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 458 |
+
},
|
| 459 |
+
)
|
| 460 |
+
DEFAULT = IMAGENET1K_V1
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
class EfficientNet_B1_Weights(WeightsEnum):
|
| 464 |
+
IMAGENET1K_V1 = Weights(
|
| 465 |
+
# Weights ported from https://github.com/rwightman/pytorch-image-models/
|
| 466 |
+
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-bac287d4.pth",
|
| 467 |
+
transforms=partial(
|
| 468 |
+
ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
|
| 469 |
+
),
|
| 470 |
+
meta={
|
| 471 |
+
**_COMMON_META_V1,
|
| 472 |
+
"num_params": 7794184,
|
| 473 |
+
"_metrics": {
|
| 474 |
+
"ImageNet-1K": {
|
| 475 |
+
"acc@1": 78.642,
|
| 476 |
+
"acc@5": 94.186,
|
| 477 |
+
}
|
| 478 |
+
},
|
| 479 |
+
"_ops": 0.687,
|
| 480 |
+
"_file_size": 30.134,
|
| 481 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 482 |
+
},
|
| 483 |
+
)
|
| 484 |
+
IMAGENET1K_V2 = Weights(
|
| 485 |
+
url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
|
| 486 |
+
transforms=partial(
|
| 487 |
+
ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
|
| 488 |
+
),
|
| 489 |
+
meta={
|
| 490 |
+
**_COMMON_META_V1,
|
| 491 |
+
"num_params": 7794184,
|
| 492 |
+
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
|
| 493 |
+
"_metrics": {
|
| 494 |
+
"ImageNet-1K": {
|
| 495 |
+
"acc@1": 79.838,
|
| 496 |
+
"acc@5": 94.934,
|
| 497 |
+
}
|
| 498 |
+
},
|
| 499 |
+
"_ops": 0.687,
|
| 500 |
+
"_file_size": 30.136,
|
| 501 |
+
"_docs": """
|
| 502 |
+
These weights improve upon the results of the original paper by using a modified version of TorchVision's
|
| 503 |
+
`new training recipe
|
| 504 |
+
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
|
| 505 |
+
""",
|
| 506 |
+
},
|
| 507 |
+
)
|
| 508 |
+
DEFAULT = IMAGENET1K_V2
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
class EfficientNet_B2_Weights(WeightsEnum):
|
| 512 |
+
IMAGENET1K_V1 = Weights(
|
| 513 |
+
# Weights ported from https://github.com/rwightman/pytorch-image-models/
|
| 514 |
+
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth",
|
| 515 |
+
transforms=partial(
|
| 516 |
+
ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
|
| 517 |
+
),
|
| 518 |
+
meta={
|
| 519 |
+
**_COMMON_META_V1,
|
| 520 |
+
"num_params": 9109994,
|
| 521 |
+
"_metrics": {
|
| 522 |
+
"ImageNet-1K": {
|
| 523 |
+
"acc@1": 80.608,
|
| 524 |
+
"acc@5": 95.310,
|
| 525 |
+
}
|
| 526 |
+
},
|
| 527 |
+
"_ops": 1.088,
|
| 528 |
+
"_file_size": 35.174,
|
| 529 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 530 |
+
},
|
| 531 |
+
)
|
| 532 |
+
DEFAULT = IMAGENET1K_V1
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class EfficientNet_B3_Weights(WeightsEnum):
|
| 536 |
+
IMAGENET1K_V1 = Weights(
|
| 537 |
+
# Weights ported from https://github.com/rwightman/pytorch-image-models/
|
| 538 |
+
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth",
|
| 539 |
+
transforms=partial(
|
| 540 |
+
ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
|
| 541 |
+
),
|
| 542 |
+
meta={
|
| 543 |
+
**_COMMON_META_V1,
|
| 544 |
+
"num_params": 12233232,
|
| 545 |
+
"_metrics": {
|
| 546 |
+
"ImageNet-1K": {
|
| 547 |
+
"acc@1": 82.008,
|
| 548 |
+
"acc@5": 96.054,
|
| 549 |
+
}
|
| 550 |
+
},
|
| 551 |
+
"_ops": 1.827,
|
| 552 |
+
"_file_size": 47.184,
|
| 553 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 554 |
+
},
|
| 555 |
+
)
|
| 556 |
+
DEFAULT = IMAGENET1K_V1
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
class EfficientNet_B4_Weights(WeightsEnum):
|
| 560 |
+
IMAGENET1K_V1 = Weights(
|
| 561 |
+
# Weights ported from https://github.com/rwightman/pytorch-image-models/
|
| 562 |
+
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth",
|
| 563 |
+
transforms=partial(
|
| 564 |
+
ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
|
| 565 |
+
),
|
| 566 |
+
meta={
|
| 567 |
+
**_COMMON_META_V1,
|
| 568 |
+
"num_params": 19341616,
|
| 569 |
+
"_metrics": {
|
| 570 |
+
"ImageNet-1K": {
|
| 571 |
+
"acc@1": 83.384,
|
| 572 |
+
"acc@5": 96.594,
|
| 573 |
+
}
|
| 574 |
+
},
|
| 575 |
+
"_ops": 4.394,
|
| 576 |
+
"_file_size": 74.489,
|
| 577 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 578 |
+
},
|
| 579 |
+
)
|
| 580 |
+
DEFAULT = IMAGENET1K_V1
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
class EfficientNet_B5_Weights(WeightsEnum):
|
| 584 |
+
IMAGENET1K_V1 = Weights(
|
| 585 |
+
# Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
|
| 586 |
+
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-1a07897c.pth",
|
| 587 |
+
transforms=partial(
|
| 588 |
+
ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
|
| 589 |
+
),
|
| 590 |
+
meta={
|
| 591 |
+
**_COMMON_META_V1,
|
| 592 |
+
"num_params": 30389784,
|
| 593 |
+
"_metrics": {
|
| 594 |
+
"ImageNet-1K": {
|
| 595 |
+
"acc@1": 83.444,
|
| 596 |
+
"acc@5": 96.628,
|
| 597 |
+
}
|
| 598 |
+
},
|
| 599 |
+
"_ops": 10.266,
|
| 600 |
+
"_file_size": 116.864,
|
| 601 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 602 |
+
},
|
| 603 |
+
)
|
| 604 |
+
DEFAULT = IMAGENET1K_V1
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
class EfficientNet_B6_Weights(WeightsEnum):
|
| 608 |
+
IMAGENET1K_V1 = Weights(
|
| 609 |
+
# Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
|
| 610 |
+
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-24a108a5.pth",
|
| 611 |
+
transforms=partial(
|
| 612 |
+
ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
|
| 613 |
+
),
|
| 614 |
+
meta={
|
| 615 |
+
**_COMMON_META_V1,
|
| 616 |
+
"num_params": 43040704,
|
| 617 |
+
"_metrics": {
|
| 618 |
+
"ImageNet-1K": {
|
| 619 |
+
"acc@1": 84.008,
|
| 620 |
+
"acc@5": 96.916,
|
| 621 |
+
}
|
| 622 |
+
},
|
| 623 |
+
"_ops": 19.068,
|
| 624 |
+
"_file_size": 165.362,
|
| 625 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 626 |
+
},
|
| 627 |
+
)
|
| 628 |
+
DEFAULT = IMAGENET1K_V1
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
class EfficientNet_B7_Weights(WeightsEnum):
|
| 632 |
+
IMAGENET1K_V1 = Weights(
|
| 633 |
+
# Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
|
| 634 |
+
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-c5b4e57e.pth",
|
| 635 |
+
transforms=partial(
|
| 636 |
+
ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
|
| 637 |
+
),
|
| 638 |
+
meta={
|
| 639 |
+
**_COMMON_META_V1,
|
| 640 |
+
"num_params": 66347960,
|
| 641 |
+
"_metrics": {
|
| 642 |
+
"ImageNet-1K": {
|
| 643 |
+
"acc@1": 84.122,
|
| 644 |
+
"acc@5": 96.908,
|
| 645 |
+
}
|
| 646 |
+
},
|
| 647 |
+
"_ops": 37.746,
|
| 648 |
+
"_file_size": 254.675,
|
| 649 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 650 |
+
},
|
| 651 |
+
)
|
| 652 |
+
DEFAULT = IMAGENET1K_V1
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
class EfficientNet_V2_S_Weights(WeightsEnum):
|
| 656 |
+
IMAGENET1K_V1 = Weights(
|
| 657 |
+
url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
|
| 658 |
+
transforms=partial(
|
| 659 |
+
ImageClassification,
|
| 660 |
+
crop_size=384,
|
| 661 |
+
resize_size=384,
|
| 662 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 663 |
+
),
|
| 664 |
+
meta={
|
| 665 |
+
**_COMMON_META_V2,
|
| 666 |
+
"num_params": 21458488,
|
| 667 |
+
"_metrics": {
|
| 668 |
+
"ImageNet-1K": {
|
| 669 |
+
"acc@1": 84.228,
|
| 670 |
+
"acc@5": 96.878,
|
| 671 |
+
}
|
| 672 |
+
},
|
| 673 |
+
"_ops": 8.366,
|
| 674 |
+
"_file_size": 82.704,
|
| 675 |
+
"_docs": """
|
| 676 |
+
These weights improve upon the results of the original paper by using a modified version of TorchVision's
|
| 677 |
+
`new training recipe
|
| 678 |
+
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
|
| 679 |
+
""",
|
| 680 |
+
},
|
| 681 |
+
)
|
| 682 |
+
DEFAULT = IMAGENET1K_V1
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
class EfficientNet_V2_M_Weights(WeightsEnum):
|
| 686 |
+
IMAGENET1K_V1 = Weights(
|
| 687 |
+
url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
|
| 688 |
+
transforms=partial(
|
| 689 |
+
ImageClassification,
|
| 690 |
+
crop_size=480,
|
| 691 |
+
resize_size=480,
|
| 692 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 693 |
+
),
|
| 694 |
+
meta={
|
| 695 |
+
**_COMMON_META_V2,
|
| 696 |
+
"num_params": 54139356,
|
| 697 |
+
"_metrics": {
|
| 698 |
+
"ImageNet-1K": {
|
| 699 |
+
"acc@1": 85.112,
|
| 700 |
+
"acc@5": 97.156,
|
| 701 |
+
}
|
| 702 |
+
},
|
| 703 |
+
"_ops": 24.582,
|
| 704 |
+
"_file_size": 208.01,
|
| 705 |
+
"_docs": """
|
| 706 |
+
These weights improve upon the results of the original paper by using a modified version of TorchVision's
|
| 707 |
+
`new training recipe
|
| 708 |
+
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
|
| 709 |
+
""",
|
| 710 |
+
},
|
| 711 |
+
)
|
| 712 |
+
DEFAULT = IMAGENET1K_V1
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
class EfficientNet_V2_L_Weights(WeightsEnum):
|
| 716 |
+
# Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
|
| 717 |
+
IMAGENET1K_V1 = Weights(
|
| 718 |
+
url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
|
| 719 |
+
transforms=partial(
|
| 720 |
+
ImageClassification,
|
| 721 |
+
crop_size=480,
|
| 722 |
+
resize_size=480,
|
| 723 |
+
interpolation=InterpolationMode.BICUBIC,
|
| 724 |
+
mean=(0.5, 0.5, 0.5),
|
| 725 |
+
std=(0.5, 0.5, 0.5),
|
| 726 |
+
),
|
| 727 |
+
meta={
|
| 728 |
+
**_COMMON_META_V2,
|
| 729 |
+
"num_params": 118515272,
|
| 730 |
+
"_metrics": {
|
| 731 |
+
"ImageNet-1K": {
|
| 732 |
+
"acc@1": 85.808,
|
| 733 |
+
"acc@5": 97.788,
|
| 734 |
+
}
|
| 735 |
+
},
|
| 736 |
+
"_ops": 56.08,
|
| 737 |
+
"_file_size": 454.573,
|
| 738 |
+
"_docs": """These weights are ported from the original paper.""",
|
| 739 |
+
},
|
| 740 |
+
)
|
| 741 |
+
DEFAULT = IMAGENET1K_V1
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
@register_model()
|
| 745 |
+
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
|
| 746 |
+
def efficientnet_b0(
|
| 747 |
+
*, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
|
| 748 |
+
) -> EfficientNet:
|
| 749 |
+
"""EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
|
| 750 |
+
Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
|
| 751 |
+
|
| 752 |
+
Args:
|
| 753 |
+
weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
|
| 754 |
+
pretrained weights to use. See
|
| 755 |
+
:class:`~torchvision.models.EfficientNet_B0_Weights` below for
|
| 756 |
+
more details, and possible values. By default, no pre-trained
|
| 757 |
+
weights are used.
|
| 758 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 759 |
+
download to stderr. Default is True.
|
| 760 |
+
**kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
|
| 761 |
+
base class. Please refer to the `source code
|
| 762 |
+
<https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
|
| 763 |
+
for more details about this class.
|
| 764 |
+
.. autoclass:: torchvision.models.EfficientNet_B0_Weights
|
| 765 |
+
:members:
|
| 766 |
+
"""
|
| 767 |
+
weights = EfficientNet_B0_Weights.verify(weights)
|
| 768 |
+
|
| 769 |
+
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
|
| 770 |
+
return _efficientnet(
|
| 771 |
+
inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
|
| 772 |
+
)
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
@register_model()
|
| 776 |
+
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
|
| 777 |
+
def efficientnet_b1(
|
| 778 |
+
*, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
|
| 779 |
+
) -> EfficientNet:
|
| 780 |
+
"""EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
|
| 781 |
+
Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
|
| 782 |
+
|
| 783 |
+
Args:
|
| 784 |
+
weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
|
| 785 |
+
pretrained weights to use. See
|
| 786 |
+
:class:`~torchvision.models.EfficientNet_B1_Weights` below for
|
| 787 |
+
more details, and possible values. By default, no pre-trained
|
| 788 |
+
weights are used.
|
| 789 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 790 |
+
download to stderr. Default is True.
|
| 791 |
+
**kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
|
| 792 |
+
base class. Please refer to the `source code
|
| 793 |
+
<https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
|
| 794 |
+
for more details about this class.
|
| 795 |
+
.. autoclass:: torchvision.models.EfficientNet_B1_Weights
|
| 796 |
+
:members:
|
| 797 |
+
"""
|
| 798 |
+
weights = EfficientNet_B1_Weights.verify(weights)
|
| 799 |
+
|
| 800 |
+
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
|
| 801 |
+
return _efficientnet(
|
| 802 |
+
inverted_residual_setting, kwargs.pop("dropout", 0.2), last_channel, weights, progress, **kwargs
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
@register_model()
|
| 807 |
+
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
|
| 808 |
+
def efficientnet_b2(
|
| 809 |
+
*, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
|
| 810 |
+
) -> EfficientNet:
|
| 811 |
+
"""EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
|
| 812 |
+
Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
|
| 813 |
+
|
| 814 |
+
Args:
|
| 815 |
+
weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
|
| 816 |
+
pretrained weights to use. See
|
| 817 |
+
:class:`~torchvision.models.EfficientNet_B2_Weights` below for
|
| 818 |
+
more details, and possible values. By default, no pre-trained
|
| 819 |
+
weights are used.
|
| 820 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 821 |
+
download to stderr. Default is True.
|
| 822 |
+
**kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
|
| 823 |
+
base class. Please refer to the `source code
|
| 824 |
+
<https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
|
| 825 |
+
for more details about this class.
|
| 826 |
+
.. autoclass:: torchvision.models.EfficientNet_B2_Weights
|
| 827 |
+
:members:
|
| 828 |
+
"""
|
| 829 |
+
weights = EfficientNet_B2_Weights.verify(weights)
|
| 830 |
+
|
| 831 |
+
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
|
| 832 |
+
return _efficientnet(
|
| 833 |
+
inverted_residual_setting, kwargs.pop("dropout", 0.3), last_channel, weights, progress, **kwargs
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
@register_model()
|
| 838 |
+
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
|
| 839 |
+
def efficientnet_b3(
|
| 840 |
+
*, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
|
| 841 |
+
) -> EfficientNet:
|
| 842 |
+
"""EfficientNet B3 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
|
| 843 |
+
Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
|
| 844 |
+
|
| 845 |
+
Args:
|
| 846 |
+
weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
|
| 847 |
+
pretrained weights to use. See
|
| 848 |
+
:class:`~torchvision.models.EfficientNet_B3_Weights` below for
|
| 849 |
+
more details, and possible values. By default, no pre-trained
|
| 850 |
+
weights are used.
|
| 851 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 852 |
+
download to stderr. Default is True.
|
| 853 |
+
**kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
|
| 854 |
+
base class. Please refer to the `source code
|
| 855 |
+
<https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
|
| 856 |
+
for more details about this class.
|
| 857 |
+
.. autoclass:: torchvision.models.EfficientNet_B3_Weights
|
| 858 |
+
:members:
|
| 859 |
+
"""
|
| 860 |
+
weights = EfficientNet_B3_Weights.verify(weights)
|
| 861 |
+
|
| 862 |
+
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
|
| 863 |
+
return _efficientnet(
|
| 864 |
+
inverted_residual_setting,
|
| 865 |
+
kwargs.pop("dropout", 0.3),
|
| 866 |
+
last_channel,
|
| 867 |
+
weights,
|
| 868 |
+
progress,
|
| 869 |
+
**kwargs,
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
@register_model()
|
| 874 |
+
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
|
| 875 |
+
def efficientnet_b4(
|
| 876 |
+
*, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
|
| 877 |
+
) -> EfficientNet:
|
| 878 |
+
"""EfficientNet B4 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
|
| 879 |
+
Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
|
| 880 |
+
|
| 881 |
+
Args:
|
| 882 |
+
weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
|
| 883 |
+
pretrained weights to use. See
|
| 884 |
+
:class:`~torchvision.models.EfficientNet_B4_Weights` below for
|
| 885 |
+
more details, and possible values. By default, no pre-trained
|
| 886 |
+
weights are used.
|
| 887 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 888 |
+
download to stderr. Default is True.
|
| 889 |
+
**kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
|
| 890 |
+
base class. Please refer to the `source code
|
| 891 |
+
<https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
|
| 892 |
+
for more details about this class.
|
| 893 |
+
.. autoclass:: torchvision.models.EfficientNet_B4_Weights
|
| 894 |
+
:members:
|
| 895 |
+
"""
|
| 896 |
+
weights = EfficientNet_B4_Weights.verify(weights)
|
| 897 |
+
|
| 898 |
+
inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
|
| 899 |
+
return _efficientnet(
|
| 900 |
+
inverted_residual_setting,
|
| 901 |
+
kwargs.pop("dropout", 0.4),
|
| 902 |
+
last_channel,
|
| 903 |
+
weights,
|
| 904 |
+
progress,
|
| 905 |
+
**kwargs,
|
| 906 |
+
)
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
@register_model()
|
| 910 |
+
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
|
| 911 |
+
def efficientnet_b5(
|
| 912 |
+
*, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
|
| 913 |
+
) -> EfficientNet:
|
| 914 |
+
"""EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
|
| 915 |
+
Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
|
| 916 |
+
|
| 917 |
+
Args:
|
| 918 |
+
weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
|
| 919 |
+
pretrained weights to use. See
|
| 920 |
+
:class:`~torchvision.models.EfficientNet_B5_Weights` below for
|
| 921 |
+
more details, and possible values. By default, no pre-trained
|
| 922 |
+
weights are used.
|
| 923 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 924 |
+
download to stderr. Default is True.
|
| 925 |
+
**kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
|
| 926 |
+
base class. Please refer to the `source code
|
| 927 |
+
<https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
|
| 928 |
+
for more details about this class.
|
| 929 |
+
.. autoclass:: torchvision.models.EfficientNet_B5_Weights
|
| 930 |
+
:members:
|
| 931 |
+
"""
|
| 932 |
+
weights = EfficientNet_B5_Weights.verify(weights)
|
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.4),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
+def efficientnet_b6(
+    *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B6_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B6_Weights
+        :members:
+    """
+    weights = EfficientNet_B6_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.5),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
+def efficientnet_b7(
+    *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
+    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_B7_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_B7_Weights
+        :members:
+    """
+    weights = EfficientNet_B7_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.5),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
+def efficientnet_v2_s(
+    *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """
+    Constructs an EfficientNetV2-S architecture from
+    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
+        :members:
+    """
+    weights = EfficientNet_V2_S_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.2),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
+def efficientnet_v2_m(
+    *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """
+    Constructs an EfficientNetV2-M architecture from
+    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
+        :members:
+    """
+    weights = EfficientNet_V2_M_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.3),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+        **kwargs,
+    )
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
+def efficientnet_v2_l(
+    *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
+) -> EfficientNet:
+    """
+    Constructs an EfficientNetV2-L architecture from
+    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
+            pretrained weights to use. See
+            :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
+        :members:
+    """
+    weights = EfficientNet_V2_L_Weights.verify(weights)
+
+    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
+    return _efficientnet(
+        inverted_residual_setting,
+        kwargs.pop("dropout", 0.4),
+        last_channel,
+        weights,
+        progress,
+        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
+        **kwargs,
+    )
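Each builder above defers its variant-specific dropout to the caller via `kwargs.pop("dropout", default)`, so an explicit keyword overrides the per-variant default while everything else flows through to `_efficientnet`. A minimal standalone sketch of that default-override pattern, with a hypothetical `make_model` standing in for `_efficientnet`:

```python
def make_model(dropout, **kwargs):
    # Stand-in for _efficientnet: just report what it received.
    return {"dropout": dropout, **kwargs}


def efficientnet_b5_like(**kwargs):
    # Mirrors the builders above: the variant default (0.4 for B5)
    # is used unless the caller passes an explicit dropout keyword,
    # and pop() removes it so it is not forwarded twice.
    return make_model(kwargs.pop("dropout", 0.4), **kwargs)


print(efficientnet_b5_like())             # variant default applies
print(efficientnet_b5_like(dropout=0.1))  # caller override survives the pop
```

The same pattern explains why passing `dropout=...` to any of these builders never raises a duplicate-keyword error: the value is consumed before `**kwargs` is re-expanded.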
.venv/lib/python3.11/site-packages/torchvision/models/googlenet.py
ADDED
|
@@ -0,0 +1,345 @@
+import warnings
+from collections import namedtuple
+from functools import partial
+from typing import Any, Callable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from ..transforms._presets import ImageClassification
+from ..utils import _log_api_usage_once
+from ._api import register_model, Weights, WeightsEnum
+from ._meta import _IMAGENET_CATEGORIES
+from ._utils import _ovewrite_named_param, handle_legacy_interface
+
+
+__all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]
+
+
+GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
+GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]}
+
+# Script annotations failed with _GoogleNetOutputs = namedtuple ...
+# _GoogLeNetOutputs set here for backwards compat
+_GoogLeNetOutputs = GoogLeNetOutputs
+
+
+class GoogLeNet(nn.Module):
+    __constants__ = ["aux_logits", "transform_input"]
+
+    def __init__(
+        self,
+        num_classes: int = 1000,
+        aux_logits: bool = True,
+        transform_input: bool = False,
+        init_weights: Optional[bool] = None,
+        blocks: Optional[List[Callable[..., nn.Module]]] = None,
+        dropout: float = 0.2,
+        dropout_aux: float = 0.7,
+    ) -> None:
+        super().__init__()
+        _log_api_usage_once(self)
+        if blocks is None:
+            blocks = [BasicConv2d, Inception, InceptionAux]
+        if init_weights is None:
+            warnings.warn(
+                "The default weight initialization of GoogleNet will be changed in future releases of "
+                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
+                " due to scipy/scipy#11299), please set init_weights=True.",
+                FutureWarning,
+            )
+            init_weights = True
+        if len(blocks) != 3:
+            raise ValueError(f"blocks length should be 3 instead of {len(blocks)}")
+        conv_block = blocks[0]
+        inception_block = blocks[1]
+        inception_aux_block = blocks[2]
+
+        self.aux_logits = aux_logits
+        self.transform_input = transform_input
+
+        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
+        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+        self.conv2 = conv_block(64, 64, kernel_size=1)
+        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
+        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+
+        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
+        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
+        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+
+        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
+        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
+        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
+        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
+        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
+        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+
+        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
+        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)
+
+        if aux_logits:
+            self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)
+            self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)
+        else:
+            self.aux1 = None  # type: ignore[assignment]
+            self.aux2 = None  # type: ignore[assignment]
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.dropout = nn.Dropout(p=dropout)
+        self.fc = nn.Linear(1024, num_classes)
+
+        if init_weights:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
+                elif isinstance(m, nn.BatchNorm2d):
+                    nn.init.constant_(m.weight, 1)
+                    nn.init.constant_(m.bias, 0)
+
+    def _transform_input(self, x: Tensor) -> Tensor:
+        if self.transform_input:
+            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
+            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
+            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
+            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
+        return x
+
+    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
+        # N x 3 x 224 x 224
+        x = self.conv1(x)
+        # N x 64 x 112 x 112
+        x = self.maxpool1(x)
+        # N x 64 x 56 x 56
+        x = self.conv2(x)
+        # N x 64 x 56 x 56
+        x = self.conv3(x)
+        # N x 192 x 56 x 56
+        x = self.maxpool2(x)
+
+        # N x 192 x 28 x 28
+        x = self.inception3a(x)
+        # N x 256 x 28 x 28
+        x = self.inception3b(x)
+        # N x 480 x 28 x 28
+        x = self.maxpool3(x)
+        # N x 480 x 14 x 14
+        x = self.inception4a(x)
+        # N x 512 x 14 x 14
+        aux1: Optional[Tensor] = None
+        if self.aux1 is not None:
+            if self.training:
+                aux1 = self.aux1(x)
+
+        x = self.inception4b(x)
+        # N x 512 x 14 x 14
+        x = self.inception4c(x)
+        # N x 512 x 14 x 14
+        x = self.inception4d(x)
+        # N x 528 x 14 x 14
+        aux2: Optional[Tensor] = None
+        if self.aux2 is not None:
+            if self.training:
+                aux2 = self.aux2(x)
+
+        x = self.inception4e(x)
+        # N x 832 x 14 x 14
+        x = self.maxpool4(x)
+        # N x 832 x 7 x 7
+        x = self.inception5a(x)
+        # N x 832 x 7 x 7
+        x = self.inception5b(x)
+        # N x 1024 x 7 x 7
+
+        x = self.avgpool(x)
+        # N x 1024 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 1024
+        x = self.dropout(x)
+        x = self.fc(x)
+        # N x 1000 (num_classes)
+        return x, aux2, aux1
+
+    @torch.jit.unused
+    def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
+        if self.training and self.aux_logits:
+            return _GoogLeNetOutputs(x, aux2, aux1)
+        else:
+            return x  # type: ignore[return-value]
+
+    def forward(self, x: Tensor) -> GoogLeNetOutputs:
+        x = self._transform_input(x)
+        x, aux1, aux2 = self._forward(x)
+        aux_defined = self.training and self.aux_logits
+        if torch.jit.is_scripting():
+            if not aux_defined:
+                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
+            return GoogLeNetOutputs(x, aux2, aux1)
+        else:
+            return self.eager_outputs(x, aux2, aux1)
+
+
+class Inception(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        ch1x1: int,
+        ch3x3red: int,
+        ch3x3: int,
+        ch5x5red: int,
+        ch5x5: int,
+        pool_proj: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
+
+        self.branch2 = nn.Sequential(
+            conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)
+        )
+
+        self.branch3 = nn.Sequential(
+            conv_block(in_channels, ch5x5red, kernel_size=1),
+            # Here, kernel_size=3 instead of kernel_size=5 is a known bug.
+            # Please see https://github.com/pytorch/vision/issues/906 for details.
+            conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
+        )
+
+        self.branch4 = nn.Sequential(
+            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
+            conv_block(in_channels, pool_proj, kernel_size=1),
+        )
+
+    def _forward(self, x: Tensor) -> List[Tensor]:
+        branch1 = self.branch1(x)
+        branch2 = self.branch2(x)
+        branch3 = self.branch3(x)
+        branch4 = self.branch4(x)
+
+        outputs = [branch1, branch2, branch3, branch4]
+        return outputs
+
+    def forward(self, x: Tensor) -> Tensor:
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_classes: int,
+        conv_block: Optional[Callable[..., nn.Module]] = None,
+        dropout: float = 0.7,
+    ) -> None:
+        super().__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.conv = conv_block(in_channels, 128, kernel_size=1)
+
+        self.fc1 = nn.Linear(2048, 1024)
+        self.fc2 = nn.Linear(1024, num_classes)
+        self.dropout = nn.Dropout(p=dropout)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
+        x = F.adaptive_avg_pool2d(x, (4, 4))
+        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
+        x = self.conv(x)
+        # N x 128 x 4 x 4
+        x = torch.flatten(x, 1)
+        # N x 2048
+        x = F.relu(self.fc1(x), inplace=True)
+        # N x 1024
+        x = self.dropout(x)
+        # N x 1024
+        x = self.fc2(x)
+        # N x 1000 (num_classes)
+
+        return x
+
+
+class BasicConv2d(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
+
+
+class GoogLeNet_Weights(WeightsEnum):
+    IMAGENET1K_V1 = Weights(
+        url="https://download.pytorch.org/models/googlenet-1378be20.pth",
+        transforms=partial(ImageClassification, crop_size=224),
+        meta={
+            "num_params": 6624904,
+            "min_size": (15, 15),
+            "categories": _IMAGENET_CATEGORIES,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
+            "_metrics": {
+                "ImageNet-1K": {
+                    "acc@1": 69.778,
+                    "acc@5": 89.530,
+                }
+            },
+            "_ops": 1.498,
+            "_file_size": 49.731,
+            "_docs": """These weights are ported from the original paper.""",
+        },
+    )
+    DEFAULT = IMAGENET1K_V1
+
+
+@register_model()
+@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
+def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
+    """GoogLeNet (Inception v1) model architecture from
+    `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`_.
+
+    Args:
+        weights (:class:`~torchvision.models.GoogLeNet_Weights`, optional): The
+            pretrained weights for the model. See
+            :class:`~torchvision.models.GoogLeNet_Weights` below for
+            more details, and possible values. By default, no pre-trained
+            weights are used.
+        progress (bool, optional): If True, displays a progress bar of the
+            download to stderr. Default is True.
+        **kwargs: parameters passed to the ``torchvision.models.GoogLeNet``
+            base class. Please refer to the `source code
+            <https://github.com/pytorch/vision/blob/main/torchvision/models/googlenet.py>`_
+            for more details about this class.
+    .. autoclass:: torchvision.models.GoogLeNet_Weights
+        :members:
+    """
+    weights = GoogLeNet_Weights.verify(weights)
+
+    original_aux_logits = kwargs.get("aux_logits", False)
+    if weights is not None:
+        if "transform_input" not in kwargs:
+            _ovewrite_named_param(kwargs, "transform_input", True)
+        _ovewrite_named_param(kwargs, "aux_logits", True)
+        _ovewrite_named_param(kwargs, "init_weights", False)
+        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
+
+    model = GoogLeNet(**kwargs)
+
+    if weights is not None:
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
+        if not original_aux_logits:
+            model.aux_logits = False
+            model.aux1 = None  # type: ignore[assignment]
+            model.aux2 = None  # type: ignore[assignment]
+        else:
+            warnings.warn(
+                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
+            )
+
+    return model
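In training mode with `aux_logits` enabled, GoogLeNet returns a `GoogLeNetOutputs` namedtuple of `(logits, aux_logits2, aux_logits1)` rather than a bare tensor. A small standalone sketch of how that namedtuple behaves for callers (string placeholders stand in for tensors):

```python
from collections import namedtuple

# Same field layout as the GoogLeNetOutputs defined in googlenet.py.
GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])

out = GoogLeNetOutputs("main", "aux2", "aux1")

# Fields are accessible by name...
print(out.logits)        # the main classifier head

# ...and the tuple unpacks in declaration order, so a training loop
# can grab all three heads for the combined loss:
logits, aux2, aux1 = out
print(logits, aux2, aux1)
```

In eval mode (or with `aux_logits=False`) `eager_outputs` returns just `x`, so inference code that indexes only the main output keeps working.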
.venv/lib/python3.11/site-packages/torchvision/models/maxvit.py
ADDED
|
@@ -0,0 +1,833 @@
| 1 |
+
import math
|
| 2 |
+
from collections import OrderedDict
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import Any, Callable, List, Optional, Sequence, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn, Tensor
|
| 10 |
+
from torchvision.models._api import register_model, Weights, WeightsEnum
|
| 11 |
+
from torchvision.models._meta import _IMAGENET_CATEGORIES
|
| 12 |
+
from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
|
| 13 |
+
from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
|
| 14 |
+
from torchvision.ops.stochastic_depth import StochasticDepth
|
| 15 |
+
from torchvision.transforms._presets import ImageClassification, InterpolationMode
|
| 16 |
+
from torchvision.utils import _log_api_usage_once
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"MaxVit",
|
| 20 |
+
"MaxVit_T_Weights",
|
| 21 |
+
"maxvit_t",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _get_conv_output_shape(input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int) -> Tuple[int, int]:
|
| 26 |
+
return (
|
| 27 |
+
(input_size[0] - kernel_size + 2 * padding) // stride + 1,
|
| 28 |
+
(input_size[1] - kernel_size + 2 * padding) // stride + 1,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]:
|
| 33 |
+
"""Util function to check that the input size is correct for a MaxVit configuration."""
|
| 34 |
+
shapes = []
|
| 35 |
+
block_input_shape = _get_conv_output_shape(input_size, 3, 2, 1)
|
| 36 |
+
for _ in range(n_blocks):
|
| 37 |
+
block_input_shape = _get_conv_output_shape(block_input_shape, 3, 2, 1)
|
| 38 |
+
shapes.append(block_input_shape)
|
| 39 |
+
return shapes
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
|
| 43 |
+
coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))
|
| 44 |
+
coords_flat = torch.flatten(coords, 1)
|
| 45 |
+
relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
|
| 46 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
| 47 |
+
relative_coords[:, :, 0] += height - 1
|
| 48 |
+
relative_coords[:, :, 1] += width - 1
|
| 49 |
+
relative_coords[:, :, 0] *= 2 * width - 1
|
| 50 |
+
return relative_coords.sum(-1)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class MBConv(nn.Module):
    """MBConv: Mobile Inverted Residual Bottleneck.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        stride (int): Stride of the depthwise convolution.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        p_stochastic_dropout (float): Probability of stochastic depth.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion_ratio: float,
        squeeze_ratio: float,
        stride: int,
        activation_layer: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        p_stochastic_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        proj: Sequence[nn.Module]
        self.proj: nn.Module

        should_proj = stride != 1 or in_channels != out_channels
        if should_proj:
            proj = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True)]
            if stride == 2:
                proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj  # type: ignore
            self.proj = nn.Sequential(*proj)
        else:
            self.proj = nn.Identity()  # type: ignore

        mid_channels = int(out_channels * expansion_ratio)
        sqz_channels = int(out_channels * squeeze_ratio)

        if p_stochastic_dropout:
            self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row")  # type: ignore
        else:
            self.stochastic_depth = nn.Identity()  # type: ignore

        _layers = OrderedDict()
        _layers["pre_norm"] = norm_layer(in_channels)
        _layers["conv_a"] = Conv2dNormActivation(
            in_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            inplace=None,
        )
        _layers["conv_b"] = Conv2dNormActivation(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            groups=mid_channels,
            inplace=None,
        )
        _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU)
        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)

        self.layers = nn.Sequential(_layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride].
        """
        res = self.proj(x)
        x = self.stochastic_depth(self.layers(x))
        return res + x

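One detail worth noting in the constructor above: both the expansion and the squeeze widths are computed from `out_channels`, not `in_channels`. As plain arithmetic (numbers chosen to match the second maxvit_t stage, `out_channels=128`, with the default ratios):

```python
# Channel bookkeeping inside MBConv (illustrative numbers, not a full module).
out_channels = 128
expansion_ratio, squeeze_ratio = 4, 0.25

mid_channels = int(out_channels * expansion_ratio)  # width inside the inverted bottleneck
sqz_channels = int(out_channels * squeeze_ratio)    # reduced width inside the SE layer

print(mid_channels, sqz_channels)  # 512 32
```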
class RelativePositionalMultiHeadAttention(nn.Module):
    """Relative Positional Multi-Head Attention.

    Args:
        feat_dim (int): Number of input features.
        head_dim (int): Number of features per head.
        max_seq_len (int): Maximum sequence length.
    """

    def __init__(
        self,
        feat_dim: int,
        head_dim: int,
        max_seq_len: int,
    ) -> None:
        super().__init__()

        if feat_dim % head_dim != 0:
            raise ValueError(f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}")

        self.n_heads = feat_dim // head_dim
        self.head_dim = head_dim
        self.size = int(math.sqrt(max_seq_len))
        self.max_seq_len = max_seq_len

        self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3)
        self.scale_factor = feat_dim**-0.5

        self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim)
        self.relative_position_bias_table = nn.parameter.Parameter(
            torch.empty(((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), dtype=torch.float32),
        )

        self.register_buffer("relative_position_index", _get_relative_position_index(self.size, self.size))
        # initialize the relative positional bias table with a truncated normal
        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def get_relative_positional_bias(self) -> torch.Tensor:
        bias_index = self.relative_position_index.view(-1)  # type: ignore
        relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
        relative_bias = relative_bias.permute(2, 0, 1).contiguous()
        return relative_bias.unsqueeze(0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, G, P, D].
        Returns:
            Tensor: Output tensor with expected layout of [B, G, P, D].
        """
        B, G, P, D = x.shape
        H, DH = self.n_heads, self.head_dim

        qkv = self.to_qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)

        k = k * self.scale_factor
        dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
        pos_bias = self.get_relative_positional_bias()

        dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)

        out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v)
        out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D)

        out = self.merge(out)
        return out

class SwapAxes(nn.Module):
    """Permute the axes of a tensor."""

    def __init__(self, a: int, b: int) -> None:
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = torch.swapaxes(x, self.a, self.b)
        return res

class WindowPartition(nn.Module):
    """
    Partition the input tensor into non-overlapping windows.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, p: int) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
            p (int): Number of partitions.
        Returns:
            Tensor: Output tensor with expected layout of [B, H/P, W/P, P*P, C].
        """
        B, C, H, W = x.shape
        P = p
        # chunk up H and W dimensions
        x = x.reshape(B, C, H // P, P, W // P, P)
        x = x.permute(0, 2, 4, 3, 5, 1)
        # collapse P * P dimension
        x = x.reshape(B, (H // P) * (W // P), P * P, C)
        return x


class WindowDepartition(nn.Module):
    """
    Departition the input tensor of non-overlapping windows into a feature volume of layout [B, C, H, W].
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, p: int, h_partitions: int, w_partitions: int) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C].
            p (int): Number of partitions.
            h_partitions (int): Number of vertical partitions.
            w_partitions (int): Number of horizontal partitions.
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H, W].
        """
        B, G, PP, C = x.shape
        P = p
        HP, WP = h_partitions, w_partitions
        # split P * P dimension into two P tile dimensions
        x = x.reshape(B, HP, WP, P, P, C)
        # permute into B, C, HP, P, WP, P
        x = x.permute(0, 5, 1, 3, 2, 4)
        # reshape into B, C, H, W
        x = x.reshape(B, C, HP * P, WP * P)
        return x

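The two modules above are inverses of each other. A nested-list sketch of the same index gymnastics on a single-channel 4x4 map with P=2 (no torch; the `partition`/`departition` helper names are illustrative only):

```python
def partition(x, p):
    # x: H x W grid -> list of windows, each a flat row-major list of p*p values
    h, w = len(x), len(x[0])
    return [
        [x[wi * p + a][wj * p + b] for a in range(p) for b in range(p)]
        for wi in range(h // p)
        for wj in range(w // p)
    ]

def departition(windows, p, hp, wp):
    # inverse: list of hp*wp windows -> (hp*p) x (wp*p) grid
    return [
        [windows[(i // p) * wp + (j // p)][(i % p) * p + (j % p)] for j in range(wp * p)]
        for i in range(hp * p)
    ]

x = [[r * 4 + c for c in range(4)] for r in range(4)]  # 0..15 in row-major order
wins = partition(x, 2)
print(wins[0])  # [0, 1, 4, 5]  (the top-left 2x2 window)
assert departition(wins, 2, 2, 2) == x  # roundtrip recovers the original map
```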
class PartitionAttentionLayer(nn.Module):
    """
    Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window.

    Args:
        in_channels (int): Number of input channels.
        head_dim (int): Dimension of each attention head.
        partition_size (int): Size of the partitions.
        partition_type (str): Type of partitioning to use. Can be either "grid" or "window".
        grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into.
        mlp_ratio (int): Ratio of the feature size expansion in the MLP layer.
        activation_layer (Callable[..., nn.Module]): Activation function to use.
        norm_layer (Callable[..., nn.Module]): Normalization function to use.
        attention_dropout (float): Dropout probability for the attention layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        p_stochastic_dropout (float): Probability of dropping out a partition.
    """

    def __init__(
        self,
        in_channels: int,
        head_dim: int,
        # partitioning parameters
        partition_size: int,
        partition_type: str,
        # grid size needs to be known at initialization time
        # because we need to know how many relative offsets there are in the grid
        grid_size: Tuple[int, int],
        mlp_ratio: int,
        activation_layer: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        attention_dropout: float,
        mlp_dropout: float,
        p_stochastic_dropout: float,
    ) -> None:
        super().__init__()

        self.n_heads = in_channels // head_dim
        self.head_dim = head_dim
        self.n_partitions = grid_size[0] // partition_size
        self.partition_type = partition_type
        self.grid_size = grid_size

        if partition_type not in ["grid", "window"]:
            raise ValueError("partition_type must be either 'grid' or 'window'")

        if partition_type == "window":
            self.p, self.g = partition_size, self.n_partitions
        else:
            self.p, self.g = self.n_partitions, partition_size

        self.partition_op = WindowPartition()
        self.departition_op = WindowDepartition()
        self.partition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
        self.departition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()

        self.attn_layer = nn.Sequential(
            norm_layer(in_channels),
            # it's always going to be partition_size ** 2 because
            # of the axis swap in the case of grid partitioning
            RelativePositionalMultiHeadAttention(in_channels, head_dim, partition_size**2),
            nn.Dropout(attention_dropout),
        )

        # pre-normalization similar to transformer layers
        self.mlp_layer = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, in_channels * mlp_ratio),
            activation_layer(),
            nn.Linear(in_channels * mlp_ratio, in_channels),
            nn.Dropout(mlp_dropout),
        )

        # layer scale factors
        self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row")

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H, W].
        """

        # Undefined behavior if H or W are not divisible by p
        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
        gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p
        torch._assert(
            self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0,
            "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format(
                self.grid_size, self.p
            ),
        )

        x = self.partition_op(x, self.p)
        x = self.partition_swap(x)
        x = x + self.stochastic_dropout(self.attn_layer(x))
        x = x + self.stochastic_dropout(self.mlp_layer(x))
        x = self.departition_swap(x)
        x = self.departition_op(x, self.p, gh, gw)

        return x

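The choice of `self.p` above is the whole trick behind the two attention variants. As plain arithmetic for the first maxvit_t stage (56x56 grid, `partition_size=7`): "window" attention tiles the grid into 7x7 blocks, while "grid" attention tiles it into 8x8 blocks and then swaps axes, so attention always runs over `partition_size ** 2 = 49` tokens; only their spatial spread differs.

```python
grid, partition_size = 56, 7
n_partitions = grid // partition_size  # 8

# "window": p = partition_size -> local 7x7 neighbourhoods
p_window = partition_size
groups_window, tokens_window = (grid // p_window) ** 2, p_window**2

# "grid": p = n_partitions -> 8x8 tiles, then SwapAxes(-2, -3) turns the
# 49 tiles of 64 tokens into 64 groups of 49 tokens strided across the image
p_grid = n_partitions
groups_grid, tokens_grid = p_grid**2, (grid // p_grid) ** 2

print(groups_window, tokens_window)  # 64 49
print(groups_grid, tokens_grid)      # 64 49
```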
class MaxVitLayer(nn.Module):
    """
    MaxVit layer consisting of an MBConv layer followed by a PartitionAttentionLayer with `window` and a PartitionAttentionLayer with `grid`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        stride (int): Stride of the depthwise convolution.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Ratio of the MLP layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        attention_dropout (float): Dropout probability for the attention layer.
        p_stochastic_dropout (float): Probability of stochastic depth.
        partition_size (int): Size of the partitions.
        grid_size (Tuple[int, int]): Size of the input feature grid.
    """

    def __init__(
        self,
        # conv parameters
        in_channels: int,
        out_channels: int,
        squeeze_ratio: float,
        expansion_ratio: float,
        stride: int,
        # conv + transformer parameters
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        # transformer parameters
        head_dim: int,
        mlp_ratio: int,
        mlp_dropout: float,
        attention_dropout: float,
        p_stochastic_dropout: float,
        # partitioning parameters
        partition_size: int,
        grid_size: Tuple[int, int],
    ) -> None:
        super().__init__()

        layers: OrderedDict = OrderedDict()

        # convolutional layer
        layers["MBconv"] = MBConv(
            in_channels=in_channels,
            out_channels=out_channels,
            expansion_ratio=expansion_ratio,
            squeeze_ratio=squeeze_ratio,
            stride=stride,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        # attention layers, block -> grid
        layers["window_attention"] = PartitionAttentionLayer(
            in_channels=out_channels,
            head_dim=head_dim,
            partition_size=partition_size,
            partition_type="window",
            grid_size=grid_size,
            mlp_ratio=mlp_ratio,
            activation_layer=activation_layer,
            norm_layer=nn.LayerNorm,
            attention_dropout=attention_dropout,
            mlp_dropout=mlp_dropout,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        layers["grid_attention"] = PartitionAttentionLayer(
            in_channels=out_channels,
            head_dim=head_dim,
            partition_size=partition_size,
            partition_type="grid",
            grid_size=grid_size,
            mlp_ratio=mlp_ratio,
            activation_layer=activation_layer,
            norm_layer=nn.LayerNorm,
            attention_dropout=attention_dropout,
            mlp_dropout=mlp_dropout,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        self.layers = nn.Sequential(layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).
        Returns:
            Tensor: Output tensor of shape (B, C, H, W).
        """
        x = self.layers(x)
        return x

class MaxVitBlock(nn.Module):
    """
    A MaxVit block consisting of `n_layers` MaxVit layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Ratio of the MLP layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        attention_dropout (float): Dropout probability for the attention layer.
        p_stochastic_dropout (float): Probability of stochastic depth.
        partition_size (int): Size of the partitions.
        input_grid_size (Tuple[int, int]): Size of the input feature grid.
        n_layers (int): Number of layers in the block.
        p_stochastic (List[float]): List of probabilities for stochastic depth for each layer.
    """

    def __init__(
        self,
        # conv parameters
        in_channels: int,
        out_channels: int,
        squeeze_ratio: float,
        expansion_ratio: float,
        # conv + transformer parameters
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        # transformer parameters
        head_dim: int,
        mlp_ratio: int,
        mlp_dropout: float,
        attention_dropout: float,
        # partitioning parameters
        partition_size: int,
        input_grid_size: Tuple[int, int],
        # number of layers
        n_layers: int,
        p_stochastic: List[float],
    ) -> None:
        super().__init__()
        if not len(p_stochastic) == n_layers:
            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")

        self.layers = nn.ModuleList()
        # account for the first stride of the first layer
        self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1)

        for idx, p in enumerate(p_stochastic):
            stride = 2 if idx == 0 else 1
            self.layers += [
                MaxVitLayer(
                    in_channels=in_channels if idx == 0 else out_channels,
                    out_channels=out_channels,
                    squeeze_ratio=squeeze_ratio,
                    expansion_ratio=expansion_ratio,
                    stride=stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    mlp_dropout=mlp_dropout,
                    attention_dropout=attention_dropout,
                    partition_size=partition_size,
                    grid_size=self.grid_size,
                    p_stochastic_dropout=p,
                ),
            ]

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).
        Returns:
            Tensor: Output tensor of shape (B, C, H, W).
        """
        for layer in self.layers:
            x = layer(x)
        return x

class MaxVit(nn.Module):
    """
    Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_ paper.
    Args:
        input_size (Tuple[int, int]): Size of the input image.
        stem_channels (int): Number of channels in the stem.
        partition_size (int): Size of the partitions.
        block_channels (List[int]): Number of channels in each block.
        block_layers (List[int]): Number of layers in each block.
        stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value.
        squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25.
        expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4.
        norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.01)`).
        activation_layer (Callable[..., nn.Module]): Activation function. Default: nn.GELU.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4.
        mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0.
        attention_dropout (float): Dropout probability for the attention layer. Default: 0.0.
        num_classes (int): Number of classes. Default: 1000.
    """

    def __init__(
        self,
        # input size parameters
        input_size: Tuple[int, int],
        # stem and task parameters
        stem_channels: int,
        # partitioning parameters
        partition_size: int,
        # block parameters
        block_channels: List[int],
        block_layers: List[int],
        # attention head dimensions
        head_dim: int,
        stochastic_depth_prob: float,
        # conv + transformer parameters
        # norm_layer is applied only to the conv layers
        # activation_layer is applied both to conv and transformer layers
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Callable[..., nn.Module] = nn.GELU,
        # conv parameters
        squeeze_ratio: float = 0.25,
        expansion_ratio: float = 4,
        # transformer parameters
        mlp_ratio: int = 4,
        mlp_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        # task parameters
        num_classes: int = 1000,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        input_channels = 3

        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030
        # for the exact parameters used in batchnorm
        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)

        # Make sure input size will be divisible by the partition size in all blocks
        # Undefined behavior if H or W are not divisible by p
        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
        block_input_sizes = _make_block_input_shapes(input_size, len(block_channels))
        for idx, block_input_size in enumerate(block_input_sizes):
            if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0:
                raise ValueError(
                    f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. "
                    f"Consider changing the partition size or the input size.\n"
                    f"Current configuration yields the following block input sizes: {block_input_sizes}."
                )

        # stem
        self.stem = nn.Sequential(
            Conv2dNormActivation(
                input_channels,
                stem_channels,
                3,
                stride=2,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
                bias=False,
                inplace=None,
            ),
            Conv2dNormActivation(
                stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True
            ),
        )

        # account for stem stride
        input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1)
        self.partition_size = partition_size

        # blocks
        self.blocks = nn.ModuleList()
        in_channels = [stem_channels] + block_channels[:-1]
        out_channels = block_channels

        # precompute the stochastic depth probabilities from 0 to stochastic_depth_prob
        # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed
        # over the range [0, stochastic_depth_prob]
        p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()

        p_idx = 0
        for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers):
            self.blocks.append(
                MaxVitBlock(
                    in_channels=in_channel,
                    out_channels=out_channel,
                    squeeze_ratio=squeeze_ratio,
                    expansion_ratio=expansion_ratio,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    mlp_dropout=mlp_dropout,
                    attention_dropout=attention_dropout,
                    partition_size=partition_size,
                    input_grid_size=input_size,
                    n_layers=num_layers,
                    p_stochastic=p_stochastic[p_idx : p_idx + num_layers],
                ),
            )
            input_size = self.blocks[-1].grid_size
            p_idx += num_layers

        # see https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1137-L1158
        # for why there is Linear -> Tanh -> Linear
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.LayerNorm(block_channels[-1]),
            nn.Linear(block_channels[-1], block_channels[-1]),
            nn.Tanh(),
            nn.Linear(block_channels[-1], num_classes, bias=False),
        )

        self._init_weights()

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        for block in self.blocks:
            x = block(x)
        x = self.classifier(x)
        return x

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

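The `np.linspace` call in `MaxVit.__init__` spreads the stochastic depth probabilities linearly across all layers of all blocks, so deeper layers are dropped more often during training. A plain-Python equivalent for the maxvit_t configuration (`block_layers=[2, 2, 5, 2]`, `stochastic_depth_prob=0.2`):

```python
block_layers = [2, 2, 5, 2]
stochastic_depth_prob = 0.2

n = sum(block_layers)  # 11 layers in total
# equivalent to np.linspace(0, stochastic_depth_prob, n).tolist()
p_stochastic = [stochastic_depth_prob * i / (n - 1) for i in range(n)]

print(round(p_stochastic[0], 3), round(p_stochastic[-1], 3))  # 0.0 0.2
# Each block then consumes its contiguous slice p_stochastic[p_idx : p_idx + num_layers].
```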
def _maxvit(
    # stem parameters
    stem_channels: int,
    # block parameters
    block_channels: List[int],
    block_layers: List[int],
    stochastic_depth_prob: float,
    # partitioning parameters
    partition_size: int,
    # transformer parameters
    head_dim: int,
    # Weights API
    weights: Optional[WeightsEnum] = None,
    progress: bool = False,
    # kwargs,
    **kwargs: Any,
) -> MaxVit:

    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"])

    input_size = kwargs.pop("input_size", (224, 224))

    model = MaxVit(
        stem_channels=stem_channels,
        block_channels=block_channels,
        block_layers=block_layers,
        stochastic_depth_prob=stochastic_depth_prob,
        head_dim=head_dim,
        partition_size=partition_size,
        input_size=input_size,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model

class MaxVit_T_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/maxvit_t-bc5ab103.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            "categories": _IMAGENET_CATEGORIES,
            "num_params": 30919624,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#maxvit",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.700,
                    "acc@5": 96.722,
                }
            },
            "_ops": 5.558,
            "_file_size": 118.769,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.
            They were trained with a BatchNorm2D momentum of 0.99 instead of the more correct 0.01.""",
        },
    )
    DEFAULT = IMAGENET1K_V1

@register_model()
|
| 799 |
+
@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1))
|
| 800 |
+
def maxvit_t(*, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
|
| 801 |
+
"""
|
| 802 |
+
Constructs a maxvit_t architecture from
|
| 803 |
+
`MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_.
|
| 804 |
+
|
| 805 |
+
Args:
|
| 806 |
+
weights (:class:`~torchvision.models.MaxVit_T_Weights`, optional): The
|
| 807 |
+
pretrained weights to use. See
|
| 808 |
+
:class:`~torchvision.models.MaxVit_T_Weights` below for
|
| 809 |
+
more details, and possible values. By default, no pre-trained
|
| 810 |
+
weights are used.
|
| 811 |
+
progress (bool, optional): If True, displays a progress bar of the
|
| 812 |
+
download to stderr. Default is True.
|
| 813 |
+
**kwargs: parameters passed to the ``torchvision.models.maxvit.MaxVit``
|
| 814 |
+
base class. Please refer to the `source code
|
| 815 |
+
<https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_
|
| 816 |
+
for more details about this class.
|
| 817 |
+
|
| 818 |
+
.. autoclass:: torchvision.models.MaxVit_T_Weights
|
| 819 |
+
:members:
|
| 820 |
+
"""
|
| 821 |
+
weights = MaxVit_T_Weights.verify(weights)
|
| 822 |
+
|
| 823 |
+
return _maxvit(
|
| 824 |
+
stem_channels=64,
|
| 825 |
+
block_channels=[64, 128, 256, 512],
|
| 826 |
+
block_layers=[2, 2, 5, 2],
|
| 827 |
+
head_dim=32,
|
| 828 |
+
stochastic_depth_prob=0.2,
|
| 829 |
+
partition_size=7,
|
| 830 |
+
weights=weights,
|
| 831 |
+
progress=progress,
|
| 832 |
+
**kwargs,
|
| 833 |
+
)
|
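As a quick sanity check on the configuration passed to `_maxvit` above: each stage splits its channel dimension into attention heads of size `head_dim`, so the per-stage head counts follow directly from `block_channels`. A minimal sketch (list values copied from the `maxvit_t` call above; `num_heads` is just an illustrative name, not part of the torchvision API):

```python
# Per-stage head counts for maxvit_t: each stage's channel count is
# divided into heads of head_dim channels each.
block_channels = [64, 128, 256, 512]
head_dim = 32
num_heads = [c // head_dim for c in block_channels]
print(num_heads)  # → [2, 4, 8, 16]
```

Note that every entry of `block_channels` is an exact multiple of `head_dim`, so no channels are dropped when forming heads.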
.venv/lib/python3.11/site-packages/torchvision/models/mobilenetv3.py
ADDED
|
@@ -0,0 +1,423 @@
from functools import partial
from typing import Any, Callable, List, Optional, Sequence

import torch
from torch import nn, Tensor

from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "MobileNetV3",
    "MobileNet_V3_Large_Weights",
    "MobileNet_V3_Small_Weights",
    "mobilenet_v3_large",
    "mobilenet_v3_small",
]


class InvertedResidualConfig:
    # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper
    def __init__(
        self,
        input_channels: int,
        kernel: int,
        expanded_channels: int,
        out_channels: int,
        use_se: bool,
        activation: str,
        stride: int,
        dilation: int,
        width_mult: float,
    ):
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.kernel = kernel
        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.use_se = use_se
        self.use_hs = activation == "HS"
        self.stride = stride
        self.dilation = dilation

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        return _make_divisible(channels * width_mult, 8)


class InvertedResidual(nn.Module):
    # Implemented as described at section 5 of MobileNetV3 paper
    def __init__(
        self,
        cnf: InvertedResidualConfig,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid),
    ):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU

        # expand
        if cnf.expanded_channels != cnf.input_channels:
            layers.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    cnf.expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                )
            )

        # depthwise
        stride = 1 if cnf.dilation > 1 else cnf.stride
        layers.append(
            Conv2dNormActivation(
                cnf.expanded_channels,
                cnf.expanded_channels,
                kernel_size=cnf.kernel,
                stride=stride,
                dilation=cnf.dilation,
                groups=cnf.expanded_channels,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
            )
        )
        if cnf.use_se:
            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
            layers.append(se_layer(cnf.expanded_channels, squeeze_channels))

        # project
        layers.append(
            Conv2dNormActivation(
                cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result += input
        return result


class MobileNetV3(nn.Module):
    def __init__(
        self,
        inverted_residual_setting: List[InvertedResidualConfig],
        last_channel: int,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """
        MobileNet V3 main class

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
            last_channel (int): The number of channels on the penultimate layer
            num_classes (int): Number of classes
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
            dropout (float): The dropout probability
        """
        super().__init__()
        _log_api_usage_once(self)

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
        ):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=3,
                stride=2,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(
            Conv2dNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(lastconv_output_channels, last_channel),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(last_channel, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _mobilenet_v3_conf(
    arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any
):
    reduce_divider = 2 if reduced_tail else 1
    dilation = 2 if dilated else 1

    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    if arch == "mobilenet_v3_large":
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
            bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
            bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
            bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
            bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
            bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
            bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
        ]
        last_channel = adjust_channels(1280 // reduce_divider)  # C5
    elif arch == "mobilenet_v3_small":
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
            bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
            bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
            bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
            bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
            bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
        ]
        last_channel = adjust_channels(1024 // reduce_divider)  # C5
    else:
        raise ValueError(f"Unsupported model type {arch}")

    return inverted_residual_setting, last_channel


def _mobilenet_v3(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> MobileNetV3:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}


class MobileNet_V3_Large_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 5483032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.042,
                    "acc@5": 91.340,
                }
            },
            "_ops": 0.217,
            "_file_size": 21.114,
            "_docs": """These weights were trained from scratch by using a simple training recipe.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 5483032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.274,
                    "acc@5": 92.566,
                }
            },
            "_ops": 0.217,
            "_file_size": 21.107,
            "_docs": """
                These weights improve marginally upon the results of the original paper by using a modified version of
                TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2


class MobileNet_V3_Small_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2542856,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 67.668,
                    "acc@5": 87.402,
                }
            },
            "_ops": 0.057,
            "_file_size": 9.829,
            "_docs": """
                These weights improve upon the results of the original paper by using a simple training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
def mobilenet_v3_large(
    *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
    """
    Constructs a large MobileNetV3 architecture from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V3_Large_Weights
        :members:
    """
    weights = MobileNet_V3_Large_Weights.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
def mobilenet_v3_small(
    *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
    """
    Constructs a small MobileNetV3 architecture from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V3_Small_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V3_Small_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V3_Small_Weights
        :members:
    """
    weights = MobileNet_V3_Small_Weights.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
    return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
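`InvertedResidualConfig.adjust_channels` above rounds `channels * width_mult` to a multiple of 8 via `_make_divisible`. A self-contained sketch of that rounding rule (a local re-implementation mirroring torchvision's `_make_divisible`, written here for illustration rather than imported from the package):

```python
def make_divisible(v, divisor=8, min_value=None):
    # Round v to the nearest multiple of divisor, never going below
    # min_value and never dropping more than 10% below the original v.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

# width_mult=0.75 applied to a 24-channel layer gives 18; rounding down
# to 16 would lose more than 10%, so the result is bumped up to 24.
print(make_divisible(24 * 0.75))  # → 24
print(make_divisible(16 * 1.0))   # → 16
```

The 10% guard is what keeps narrow layers from collapsing when small width multipliers are used.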
.venv/lib/python3.11/site-packages/torchvision/models/squeezenet.py
ADDED
|
@@ -0,0 +1,223 @@
| 1 |
+
from functools import partial
|
| 2 |
+
from typing import Any, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.init as init
|
| 7 |
+
|
| 8 |
+
from ..transforms._presets import ImageClassification
|
| 9 |
+
from ..utils import _log_api_usage_once
|
| 10 |
+
from ._api import register_model, Weights, WeightsEnum
|
| 11 |
+
from ._meta import _IMAGENET_CATEGORIES
|
| 12 |
+
from ._utils import _ovewrite_named_param, handle_legacy_interface
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
__all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Fire(nn.Module):
|
| 19 |
+
def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.inplanes = inplanes
|
| 22 |
+
self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
|
| 23 |
+
self.squeeze_activation = nn.ReLU(inplace=True)
|
| 24 |
+
self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
|
| 25 |
+
self.expand1x1_activation = nn.ReLU(inplace=True)
|
| 26 |
+
self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
|
| 27 |
+
self.expand3x3_activation = nn.ReLU(inplace=True)
|
| 28 |
+
|
| 29 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 30 |
+
x = self.squeeze_activation(self.squeeze(x))
|
| 31 |
+
return torch.cat(
|
| 32 |
+
[self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class SqueezeNet(nn.Module):
|
| 37 |
+
def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
|
| 38 |
+
super().__init__()
|
| 39 |
+
_log_api_usage_once(self)
|
| 40 |
+
self.num_classes = num_classes
|
| 41 |
+
if version == "1_0":
|
| 42 |
+
self.features = nn.Sequential(
|
| 43 |
+
nn.Conv2d(3, 96, kernel_size=7, stride=2),
|
| 44 |
+
nn.ReLU(inplace=True),
|
| 45 |
+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
|
| 46 |
+
Fire(96, 16, 64, 64),
|
| 47 |
+
Fire(128, 16, 64, 64),
|
| 48 |
+
Fire(128, 32, 128, 128),
|
| 49 |
+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
|
| 50 |
+
Fire(256, 32, 128, 128),
|
| 51 |
+
Fire(256, 48, 192, 192),
|
| 52 |
+
Fire(384, 48, 192, 192),
|
| 53 |
+
Fire(384, 64, 256, 256),
|
| 54 |
+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
|
| 55 |
+
Fire(512, 64, 256, 256),
|
| 56 |
+
)
|
| 57 |
+
elif version == "1_1":
|
| 58 |
+
self.features = nn.Sequential(
|
| 59 |
+
nn.Conv2d(3, 64, kernel_size=3, stride=2),
|
| 60 |
+
nn.ReLU(inplace=True),
|
| 61 |
+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
|
| 62 |
+
Fire(64, 16, 64, 64),
|
| 63 |
+
Fire(128, 16, 64, 64),
|
| 64 |
+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
|
| 65 |
+
Fire(128, 32, 128, 128),
|
| 66 |
+
Fire(256, 32, 128, 128),
|
| 67 |
+
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
|
| 68 |
+
Fire(256, 48, 192, 192),
|
| 69 |
+
Fire(384, 48, 192, 192),
|
| 70 |
+
Fire(384, 64, 256, 256),
|
| 71 |
+
Fire(512, 64, 256, 256),
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
# FIXME: Is this needed? SqueezeNet should only be called from the
|
| 75 |
+
# FIXME: squeezenet1_x() functions
|
| 76 |
+
# FIXME: This checking is not done for the other models
|
| 77 |
+
raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")
|
| 78 |
+
|
| 79 |
+
# Final convolution is initialized differently from the rest
|
| 80 |
+
final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
|
| 81 |
+
self.classifier = nn.Sequential(
|
| 82 |
+
nn.Dropout(p=dropout), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1))
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
for m in self.modules():
|
| 86 |
+
if isinstance(m, nn.Conv2d):
|
| 87 |
+
if m is final_conv:
|
| 88 |
+
init.normal_(m.weight, mean=0.0, std=0.01)
|
| 89 |
+
else:
|
| 90 |
+
init.kaiming_uniform_(m.weight)
|
| 91 |
+
if m.bias is not None:
|
| 92 |
+
init.constant_(m.bias, 0)
|
| 93 |
+
|
| 94 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 95 |
+
x = self.features(x)
|
| 96 |
+
x = self.classifier(x)
|
| 97 |
+
return torch.flatten(x, 1)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _squeezenet(
|
| 101 |
+
version: str,
|
| 102 |
+
weights: Optional[WeightsEnum],
|
| 103 |
+
progress: bool,
|
| 104 |
+
**kwargs: Any,
|
| 105 |
+
) -> SqueezeNet:
|
| 106 |
+
if weights is not None:
|
| 107 |
+
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
|
| 108 |
+
|
| 109 |
+
model = SqueezeNet(version, **kwargs)
|
| 110 |
+
|
| 111 |
+
if weights is not None:
|
| 112 |
+
model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
|
| 113 |
+
|
| 114 |
+
return model
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
_COMMON_META = {
|
| 118 |
+
"categories": _IMAGENET_CATEGORIES,
|
| 119 |
+
"recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
|
| 120 |
+
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class SqueezeNet1_0_Weights(WeightsEnum):
|
| 125 |
+
IMAGENET1K_V1 = Weights(
|
| 126 |
+
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
|
| 127 |
+
transforms=partial(ImageClassification, crop_size=224),
|
| 128 |
+
meta={
|
| 129 |
+
**_COMMON_META,
|
| 130 |
+
"min_size": (21, 21),
|
| 131 |
+
"num_params": 1248424,
|
| 132 |
+
"_metrics": {
|
| 133 |
+
"ImageNet-1K": {
|
| 134 |
+
"acc@1": 58.092,
|
| 135 |
+
"acc@5": 80.420,
|
| 136 |
+
}
|
| 137 |
+
},
|
| 138 |
+
"_ops": 0.819,
|
| 139 |
+
"_file_size": 4.778,
|
| 140 |
+
},
|
| 141 |
+
)
|
| 142 |
+
DEFAULT = IMAGENET1K_V1
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class SqueezeNet1_1_Weights(WeightsEnum):
|
| 146 |
+
IMAGENET1K_V1 = Weights(
|
| 147 |
+
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
|
| 148 |
+
transforms=partial(ImageClassification, crop_size=224),
|
| 149 |
+
meta={
|
| 150 |
+
**_COMMON_META,
|
| 151 |
+
"min_size": (17, 17),
|
| 152 |
+
"num_params": 1235496,
|
| 153 |
+
"_metrics": {
|
| 154 |
+
"ImageNet-1K": {
|
| 155 |
+
"acc@1": 58.178,
|
| 156 |
+
"acc@5": 80.624,
|
| 157 |
+
}
|
| 158 |
+
},
|
| 159 |
+
"_ops": 0.349,
|
| 160 |
+
"_file_size": 4.729,
|
| 161 |
+
},
|
| 162 |
+
)
|
| 163 |
+
DEFAULT = IMAGENET1K_V1
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
@register_model()
|
| 167 |
+
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1))
|
| 168 |
+
def squeezenet1_0(
|
| 169 |
+
*, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
|
| 170 |
+
) -> SqueezeNet:
|
| 171 |
+
"""SqueezeNet model architecture from the `SqueezeNet: AlexNet-level
|
| 172 |
+
accuracy with 50x fewer parameters and <0.5MB model size
|
| 173 |
+
<https://arxiv.org/abs/1602.07360>`_ paper.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
        weights (:class:`~torchvision.models.SqueezeNet1_0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.SqueezeNet1_0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.SqueezeNet1_0_Weights
        :members:
    """
    weights = SqueezeNet1_0_Weights.verify(weights)
    return _squeezenet("1_0", weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1))
def squeezenet1_1(
    *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
    """SqueezeNet 1.1 model from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.

    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.

    Args:
        weights (:class:`~torchvision.models.SqueezeNet1_1_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.SqueezeNet1_1_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.squeezenet.SqueezeNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.SqueezeNet1_1_Weights
        :members:
    """
    weights = SqueezeNet1_1_Weights.verify(weights)
    return _squeezenet("1_1", weights, progress, **kwargs)
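The ``@handle_legacy_interface`` decorator on these builders remaps the old ``pretrained=True`` flag onto the new ``weights`` enum argument. A minimal sketch of that idea, assuming a simplified decorator and a stub builder (``legacy_pretrained`` and ``squeezenet_stub`` are hypothetical names, not torchvision's actual implementation):

```python
import functools


def legacy_pretrained(default_weights):
    """Map a legacy ``pretrained=True`` flag onto the new ``weights`` argument."""
    def decorator(builder):
        @functools.wraps(builder)
        def wrapper(*args, pretrained=False, weights=None, **kwargs):
            if pretrained and weights is None:
                # Legacy callers get the designated default weights.
                weights = default_weights
            return builder(*args, weights=weights, **kwargs)
        return wrapper
    return decorator


@legacy_pretrained("IMAGENET1K_V1")
def squeezenet_stub(*, weights=None):
    # Stand-in for a real model builder; just returns the resolved weights.
    return weights


print(squeezenet_stub(pretrained=True))  # -> IMAGENET1K_V1
print(squeezenet_stub())                 # -> None
```

The real decorator additionally emits a deprecation warning and verifies the enum value, but the remapping shown here is the core behavior.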
.venv/lib/python3.11/site-packages/torchvision/models/vgg.py
ADDED
@@ -0,0 +1,511 @@
from functools import partial
from typing import Any, cast, Dict, List, Optional, Union

import torch
import torch.nn as nn

from ..transforms._presets import ImageClassification
from ..utils import _log_api_usage_once
from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface


__all__ = [
    "VGG",
    "VGG11_Weights",
    "VGG11_BN_Weights",
    "VGG13_Weights",
    "VGG13_BN_Weights",
    "VGG16_Weights",
    "VGG16_BN_Weights",
    "VGG19_Weights",
    "VGG19_BN_Weights",
    "vgg11",
    "vgg11_bn",
    "vgg13",
    "vgg13_bn",
    "vgg16",
    "vgg16_bn",
    "vgg19",
    "vgg19_bn",
]


class VGG(nn.Module):
    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
    if weights is not None:
        kwargs["init_weights"] = False
        if weights.meta["categories"] is not None:
            _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
    model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model


_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
    "_docs": """These weights were trained from scratch by using a simplified training recipe.""",
}


class VGG11_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg11-8a719046.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 132863336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.020,
                    "acc@5": 88.628,
                }
            },
            "_ops": 7.609,
            "_file_size": 506.84,
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG11_BN_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 132868840,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 70.370,
                    "acc@5": 89.810,
                }
            },
            "_ops": 7.609,
            "_file_size": 506.881,
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG13_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg13-19584684.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 133047848,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.928,
                    "acc@5": 89.246,
                }
            },
            "_ops": 11.308,
            "_file_size": 507.545,
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG13_BN_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 133053736,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.586,
                    "acc@5": 90.374,
                }
            },
            "_ops": 11.308,
            "_file_size": 507.59,
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG16_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16-397923af.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.592,
                    "acc@5": 90.382,
                }
            },
            "_ops": 15.47,
            "_file_size": 527.796,
        },
    )
    IMAGENET1K_FEATURES = Weights(
        # Weights ported from https://github.com/amdegroot/ssd.pytorch/
        url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            mean=(0.48235, 0.45882, 0.40784),
            std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
        ),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            "categories": None,
            "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": float("nan"),
                    "acc@5": float("nan"),
                }
            },
            "_ops": 15.47,
            "_file_size": 527.802,
            "_docs": """
                These weights can't be used for classification because they are missing values in the `classifier`
                module. Only the `features` module has valid values and can be used for feature extraction. The weights
                were trained using the original input standardization method as described in the paper.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG16_BN_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138365992,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.360,
                    "acc@5": 91.516,
                }
            },
            "_ops": 15.47,
            "_file_size": 527.866,
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG19_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143667240,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.376,
                    "acc@5": 90.876,
                }
            },
            "_ops": 19.632,
            "_file_size": 548.051,
        },
    )
    DEFAULT = IMAGENET1K_V1


class VGG19_BN_Weights(WeightsEnum):
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143678248,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.218,
                    "acc@5": 91.842,
                }
            },
            "_ops": 19.632,
            "_file_size": 548.143,
        },
    )
    DEFAULT = IMAGENET1K_V1


@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG11_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG11_Weights
        :members:
    """
    weights = VGG11_Weights.verify(weights)

    return _vgg("A", False, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG11_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG11_BN_Weights
        :members:
    """
    weights = VGG11_BN_Weights.verify(weights)

    return _vgg("A", True, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG13_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG13_Weights
        :members:
    """
    weights = VGG13_Weights.verify(weights)

    return _vgg("B", False, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG13_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG13_BN_Weights
        :members:
    """
    weights = VGG13_BN_Weights.verify(weights)

    return _vgg("B", True, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG16_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG16_Weights
        :members:
    """
    weights = VGG16_Weights.verify(weights)

    return _vgg("D", False, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG16_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG16_BN_Weights
        :members:
    """
    weights = VGG16_BN_Weights.verify(weights)

    return _vgg("D", True, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG19_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG19_Weights
        :members:
    """
    weights = VGG19_Weights.verify(weights)

    return _vgg("E", False, weights, progress, **kwargs)


@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.VGG19_BN_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.vgg.VGG``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.VGG19_BN_Weights
        :members:
    """
    weights = VGG19_BN_Weights.verify(weights)

    return _vgg("E", True, weights, progress, **kwargs)
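The `cfgs` dict in `vgg.py` encodes each architecture as a flat list: every integer is one 3x3 conv layer (its output channel count) and every `"M"` is a max-pool. A small pure-Python sketch showing how these strings map onto the model names (`CFGS` and `weight_layers` are illustrative names, not part of torchvision; the convention that VGG-N counts conv plus fully connected layers is from the paper):

```python
# Copies of the torchvision ``cfgs`` entries, for illustration only.
CFGS = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],           # vgg11
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],  # vgg13
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],  # vgg16
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],  # vgg19
}


def weight_layers(cfg):
    # Every non-"M" entry becomes one nn.Conv2d in make_layers; the classifier
    # always adds three nn.Linear layers, which completes the VGG-N count.
    convs = sum(1 for v in cfg if v != "M")
    return convs + 3


for name, cfg in CFGS.items():
    print(name, weight_layers(cfg))  # A 11, B 13, D 16, E 19
```

This is why config `"D"` is used by `vgg16`/`vgg16_bn` and `"E"` by `vgg19`/`vgg19_bn` in the builders above.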
.venv/lib/python3.11/site-packages/torchvision/ops/__init__.py
ADDED
@@ -0,0 +1,73 @@
from ._register_onnx_ops import _register_custom_op
from .boxes import (
    batched_nms,
    box_area,
    box_convert,
    box_iou,
    clip_boxes_to_image,
    complete_box_iou,
    distance_box_iou,
    generalized_box_iou,
    masks_to_boxes,
    nms,
    remove_small_boxes,
)
from .ciou_loss import complete_box_iou_loss
from .deform_conv import deform_conv2d, DeformConv2d
from .diou_loss import distance_box_iou_loss
from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss
from .misc import Conv2dNormActivation, Conv3dNormActivation, FrozenBatchNorm2d, MLP, Permute, SqueezeExcitation
from .poolers import MultiScaleRoIAlign
from .ps_roi_align import ps_roi_align, PSRoIAlign
from .ps_roi_pool import ps_roi_pool, PSRoIPool
from .roi_align import roi_align, RoIAlign
from .roi_pool import roi_pool, RoIPool
from .stochastic_depth import stochastic_depth, StochasticDepth

_register_custom_op()


__all__ = [
    "masks_to_boxes",
    "deform_conv2d",
    "DeformConv2d",
    "nms",
    "batched_nms",
    "remove_small_boxes",
    "clip_boxes_to_image",
    "box_convert",
    "box_area",
    "box_iou",
    "generalized_box_iou",
    "distance_box_iou",
    "complete_box_iou",
    "roi_align",
    "RoIAlign",
    "roi_pool",
    "RoIPool",
    "ps_roi_align",
    "PSRoIAlign",
    "ps_roi_pool",
    "PSRoIPool",
    "MultiScaleRoIAlign",
    "FeaturePyramidNetwork",
    "sigmoid_focal_loss",
    "stochastic_depth",
    "StochasticDepth",
    "FrozenBatchNorm2d",
    "Conv2dNormActivation",
    "Conv3dNormActivation",
    "SqueezeExcitation",
    "MLP",
    "Permute",
    "generalized_box_iou_loss",
    "distance_box_iou_loss",
    "complete_box_iou_loss",
    "drop_block2d",
    "DropBlock2d",
    "drop_block3d",
    "DropBlock3d",
]
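Several of the box utilities exported here (`box_iou`, `nms`, the IoU-based losses) build on the same pairwise intersection-over-union computation. A minimal pure-Python sketch of that computation, assuming boxes in `(x1, y1, x2, y2)` format with `x2 > x1` and `y2 > y1` (illustrative only, not the library's tensor implementation):

```python
def area(box):
    # Area of an axis-aligned box in (x1, y1, x2, y2) format.
    x1, y1, x2, y2 = box
    return (x2 - x1) * (y2 - y1)


def iou(a, b):
    # Intersection rectangle, clamped to zero extent when the boxes are disjoint.
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    union = area(a) + area(b) - inter
    return inter / union


print(iou((0, 0, 2, 2), (0, 0, 2, 2)))  # 1.0 (identical boxes)
print(iou((0, 0, 2, 2), (1, 1, 3, 3)))  # 1/7: overlap 1, union 4 + 4 - 1
print(iou((0, 0, 1, 1), (2, 2, 3, 3)))  # 0.0 (disjoint)
```

`torchvision.ops.box_iou` computes this for all pairs between two `(N, 4)` and `(M, 4)` tensors at once, and `nms` uses the same quantity to suppress overlapping detections above a threshold.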
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.22 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_box_convert.cpython-311.pyc
ADDED
Binary file (3.46 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_register_onnx_ops.cpython-311.pyc
ADDED
Binary file (6.19 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_utils.cpython-311.pyc
ADDED
Binary file (7.17 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/boxes.cpython-311.pyc
ADDED
Binary file (23.3 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/ciou_loss.cpython-311.pyc
ADDED
Binary file (3.94 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/deform_conv.cpython-311.pyc
ADDED
Binary file (9.67 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/diou_loss.cpython-311.pyc
ADDED
Binary file (4.3 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/drop_block.cpython-311.pyc
ADDED
Binary file (9.13 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/feature_pyramid_network.cpython-311.pyc
ADDED
Binary file (13 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/focal_loss.cpython-311.pyc
ADDED
Binary file (2.91 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/giou_loss.cpython-311.pyc
ADDED
Binary file (3.65 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/misc.cpython-311.pyc
ADDED
Binary file (20 kB)
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/poolers.cpython-311.pyc
ADDED
Binary file (16.5 kB)