koichi12 commited on
Commit
5fce27e
·
verified ·
1 Parent(s): 8ea62b1

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/AUTHORS +19 -0
  2. .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/INSTALLER +1 -0
  3. .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/LICENSE +29 -0
  4. .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/METADATA +98 -0
  5. .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/RECORD +29 -0
  6. .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/WHEEL +6 -0
  7. .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/top_level.txt +1 -0
  8. .venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/zip-safe +1 -0
  9. .venv/lib/python3.11/site-packages/huggingface_hub-0.28.1.dist-info/METADATA +308 -0
  10. .venv/lib/python3.11/site-packages/huggingface_hub-0.28.1.dist-info/top_level.txt +1 -0
  11. .venv/lib/python3.11/site-packages/nvidia_cuda_cupti_cu12-12.4.127.dist-info/METADATA +35 -0
  12. .venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/LICENSE +22 -0
  13. .venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/METADATA +586 -0
  14. .venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/WHEEL +5 -0
  15. .venv/lib/python3.11/site-packages/torchvision/__init__.py +105 -0
  16. .venv/lib/python3.11/site-packages/torchvision/__pycache__/__init__.cpython-311.pyc +0 -0
  17. .venv/lib/python3.11/site-packages/torchvision/__pycache__/_internally_replaced_utils.cpython-311.pyc +0 -0
  18. .venv/lib/python3.11/site-packages/torchvision/__pycache__/_meta_registrations.cpython-311.pyc +0 -0
  19. .venv/lib/python3.11/site-packages/torchvision/__pycache__/_utils.cpython-311.pyc +0 -0
  20. .venv/lib/python3.11/site-packages/torchvision/__pycache__/extension.cpython-311.pyc +0 -0
  21. .venv/lib/python3.11/site-packages/torchvision/__pycache__/utils.cpython-311.pyc +0 -0
  22. .venv/lib/python3.11/site-packages/torchvision/__pycache__/version.cpython-311.pyc +0 -0
  23. .venv/lib/python3.11/site-packages/torchvision/_internally_replaced_utils.py +50 -0
  24. .venv/lib/python3.11/site-packages/torchvision/_meta_registrations.py +225 -0
  25. .venv/lib/python3.11/site-packages/torchvision/_utils.py +32 -0
  26. .venv/lib/python3.11/site-packages/torchvision/extension.py +92 -0
  27. .venv/lib/python3.11/site-packages/torchvision/models/alexnet.py +119 -0
  28. .venv/lib/python3.11/site-packages/torchvision/models/convnext.py +414 -0
  29. .venv/lib/python3.11/site-packages/torchvision/models/densenet.py +448 -0
  30. .venv/lib/python3.11/site-packages/torchvision/models/efficientnet.py +1131 -0
  31. .venv/lib/python3.11/site-packages/torchvision/models/googlenet.py +345 -0
  32. .venv/lib/python3.11/site-packages/torchvision/models/maxvit.py +833 -0
  33. .venv/lib/python3.11/site-packages/torchvision/models/mobilenetv3.py +423 -0
  34. .venv/lib/python3.11/site-packages/torchvision/models/squeezenet.py +223 -0
  35. .venv/lib/python3.11/site-packages/torchvision/models/vgg.py +511 -0
  36. .venv/lib/python3.11/site-packages/torchvision/ops/__init__.py +73 -0
  37. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/__init__.cpython-311.pyc +0 -0
  38. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_box_convert.cpython-311.pyc +0 -0
  39. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_register_onnx_ops.cpython-311.pyc +0 -0
  40. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_utils.cpython-311.pyc +0 -0
  41. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/boxes.cpython-311.pyc +0 -0
  42. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/ciou_loss.cpython-311.pyc +0 -0
  43. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/deform_conv.cpython-311.pyc +0 -0
  44. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/diou_loss.cpython-311.pyc +0 -0
  45. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/drop_block.cpython-311.pyc +0 -0
  46. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/feature_pyramid_network.cpython-311.pyc +0 -0
  47. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/focal_loss.cpython-311.pyc +0 -0
  48. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/giou_loss.cpython-311.pyc +0 -0
  49. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/misc.cpython-311.pyc +0 -0
  50. .venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/poolers.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/AUTHORS ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Original author of astor/codegen.py:
2
+ * Armin Ronacher <armin.ronacher@active-4.com>
3
+
4
+ And with some modifications based on Armin's code:
5
+ * Paul Dubs <paul.dubs@gmail.com>
6
+
7
+ * Berker Peksag <berker.peksag@gmail.com>
8
+ * Patrick Maupin <pmaupin@gmail.com>
9
+ * Abhishek L <abhishek.lekshmanan@gmail.com>
10
+ * Bob Tolbert <bob@eyesopen.com>
11
+ * Whyzgeek <whyzgeek@gmail.com>
12
+ * Zack M. Davis <code@zackmdavis.net>
13
+ * Ryan Gonzalez <rymg19@gmail.com>
14
+ * Lenny Truong <leonardtruong@protonmail.com>
15
+ * Radomír Bosák <radomir.bosak@gmail.com>
16
+ * Kodi Arfer <git@arfer.net>
17
+ * Felix Yan <felixonmars@archlinux.org>
18
+ * Chris Rink <chrisrink10@gmail.com>
19
+ * Batuhan Taskaya <batuhanosmantaskaya@gmail.com>
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/LICENSE ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright (c) 2012, Patrick Maupin
2
+ Copyright (c) 2013, Berker Peksag
3
+ Copyright (c) 2008, Armin Ronacher
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification,
7
+ are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation and/or
14
+ other materials provided with the distribution.
15
+
16
+ 3. Neither the name of the copyright holder nor the names of its contributors
17
+ may be used to endorse or promote products derived from this software without
18
+ specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
21
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
22
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
24
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
25
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
26
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
27
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/METADATA ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: astor
3
+ Version: 0.8.1
4
+ Summary: Read/rewrite/write Python ASTs
5
+ Home-page: https://github.com/berkerpeksag/astor
6
+ Author: Patrick Maupin
7
+ Author-email: pmaupin@gmail.com
8
+ License: BSD-3-Clause
9
+ Keywords: ast,codegen,PEP 8
10
+ Platform: Independent
11
+ Classifier: Development Status :: 5 - Production/Stable
12
+ Classifier: Environment :: Console
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: License :: OSI Approved :: BSD License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python
17
+ Classifier: Programming Language :: Python :: 2
18
+ Classifier: Programming Language :: Python :: 2.7
19
+ Classifier: Programming Language :: Python :: 3
20
+ Classifier: Programming Language :: Python :: 3.4
21
+ Classifier: Programming Language :: Python :: 3.5
22
+ Classifier: Programming Language :: Python :: 3.6
23
+ Classifier: Programming Language :: Python :: 3.7
24
+ Classifier: Programming Language :: Python :: 3.8
25
+ Classifier: Programming Language :: Python :: Implementation
26
+ Classifier: Programming Language :: Python :: Implementation :: CPython
27
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
28
+ Classifier: Topic :: Software Development :: Code Generators
29
+ Classifier: Topic :: Software Development :: Compilers
30
+ Requires-Python: !=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7
31
+
32
+ =============================
33
+ astor -- AST observe/rewrite
34
+ =============================
35
+
36
+ :PyPI: https://pypi.org/project/astor/
37
+ :Documentation: https://astor.readthedocs.io
38
+ :Source: https://github.com/berkerpeksag/astor
39
+ :License: 3-clause BSD
40
+ :Build status:
41
+ .. image:: https://secure.travis-ci.org/berkerpeksag/astor.svg
42
+ :alt: Travis CI
43
+ :target: https://travis-ci.org/berkerpeksag/astor/
44
+
45
+ astor is designed to allow easy manipulation of Python source via the AST.
46
+
47
+ There are some other similar libraries, but astor focuses on the following areas:
48
+
49
+ - Round-trip an AST back to Python [1]_:
50
+
51
+ - Modified AST doesn't need linenumbers, ctx, etc. or otherwise
52
+ be directly compileable for the round-trip to work.
53
+ - Easy to read generated code as, well, code
54
+ - Can round-trip two different source trees to compare for functional
55
+ differences, using the astor.rtrip tool (for example, after PEP8 edits).
56
+
57
+ - Dump pretty-printing of AST
58
+
59
+ - Harder to read than round-tripped code, but more accurate to figure out what
60
+ is going on.
61
+
62
+ - Easier to read than dump from built-in AST module
63
+
64
+ - Non-recursive treewalk
65
+
66
+ - Sometimes you want a recursive treewalk (and astor supports that, starting
67
+ at any node on the tree), but sometimes you don't need to do that. astor
68
+ doesn't require you to explicitly visit sub-nodes unless you want to:
69
+
70
+ - You can add code that executes before a node's children are visited, and/or
71
+ - You can add code that executes after a node's children are visited, and/or
72
+ - You can add code that executes and keeps the node's children from being
73
+ visited (and optionally visit them yourself via a recursive call)
74
+
75
+ - Write functions to access the tree based on object names and/or attribute names
76
+ - Enjoy easy access to parent node(s) for tree rewriting
77
+
78
+ .. [1]
79
+ The decompilation back to Python is based on code originally written
80
+ by Armin Ronacher. Armin's code was well-structured, but failed on
81
+ some obscure corner cases of the Python language (and even more corner
82
+ cases when the AST changed on different versions of Python), and its
83
+ output arguably had cosmetic issues -- for example, it produced
84
+ parentheses even in some cases where they were not needed, to
85
+ avoid having to reason about precedence.
86
+
87
+ Other derivatives of Armin's code are floating around, and typically
88
+ have fixes for a few corner cases that happened to be noticed by the
89
+ maintainers, but most of them have not been tested as thoroughly as
90
+ astor. One exception may be the version of codegen
91
+ `maintained at github by CensoredUsername`__. This has been tested
92
+ to work properly on Python 2.7 using astor's test suite, and, as it
93
+ is a single source file, it may be easier to drop into some applications
94
+ that do not require astor's other features or Python 3.x compatibility.
95
+
96
+ __ https://github.com/CensoredUsername/codegen
97
+
98
+
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/RECORD ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ astor-0.8.1.dist-info/AUTHORS,sha256=dy5MQIMINxY79YbaRR19C_CNAgHe3tcuvESs7ypxKQc,679
2
+ astor-0.8.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
3
+ astor-0.8.1.dist-info/LICENSE,sha256=zkHq_C78AY2cfJahx3lmgkbHfbEaE544ifNH9GSmG50,1554
4
+ astor-0.8.1.dist-info/METADATA,sha256=0nH_-dzD0tPZUB4Hs5o-OOEuId9lteVELQPI5hG0oKo,4235
5
+ astor-0.8.1.dist-info/RECORD,,
6
+ astor-0.8.1.dist-info/WHEEL,sha256=8zNYZbwQSXoB9IfXOjPfeNwvAsALAjffgk27FqvCWbo,110
7
+ astor-0.8.1.dist-info/top_level.txt,sha256=M5xfrbiL9-EIlOb1h2T8s6gFbV3b9AbwgI0ARzaRyaY,6
8
+ astor-0.8.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
9
+ astor/VERSION,sha256=qvZyHcN8QLQjOsz8CB8ld2_zvR0qS51c6nYNHCz4ZmU,6
10
+ astor/__init__.py,sha256=C9rmH4v9K7pkIk3eDuVRhqO5wULt3B42copNJsEw8rw,2291
11
+ astor/__pycache__/__init__.cpython-311.pyc,,
12
+ astor/__pycache__/code_gen.cpython-311.pyc,,
13
+ astor/__pycache__/codegen.cpython-311.pyc,,
14
+ astor/__pycache__/file_util.cpython-311.pyc,,
15
+ astor/__pycache__/node_util.cpython-311.pyc,,
16
+ astor/__pycache__/op_util.cpython-311.pyc,,
17
+ astor/__pycache__/rtrip.cpython-311.pyc,,
18
+ astor/__pycache__/source_repr.cpython-311.pyc,,
19
+ astor/__pycache__/string_repr.cpython-311.pyc,,
20
+ astor/__pycache__/tree_walk.cpython-311.pyc,,
21
+ astor/code_gen.py,sha256=0KAimfyV8pIPXQx6s_NyPSXRhAxMLWXbCPEQuCTpxac,32032
22
+ astor/codegen.py,sha256=lTqdJWMK4EAJ1wxDw2XR-MLyHJmvbV1_Q5QLj9naE_g,204
23
+ astor/file_util.py,sha256=BETsKYg8UiKoZNswRkirzPSZWgku41dRzZC7T5X3_F4,3268
24
+ astor/node_util.py,sha256=WEWMUMSfHtLwgx54nMkc2APLV573iOPhqPag4gIbhVQ,6542
25
+ astor/op_util.py,sha256=GGcgYqa3DFOAaoSt7TTu46VUhe1J13dO14-SQTRXRYI,3191
26
+ astor/rtrip.py,sha256=AlvQvsUuUZ8zxvRFpWF_Fsv4-NksPB23rvVkTrkvef8,6741
27
+ astor/source_repr.py,sha256=1lj4jakkrcGDRoo-BIRZDszQ8gukdeLR_fmvGqBrP-U,7373
28
+ astor/string_repr.py,sha256=YeC_DVeIJdPElqjgzzhPFheQsz_QjMEW_SLODFvEsIA,2917
29
+ astor/tree_walk.py,sha256=fJaw54GgTg4NTRJLVRl2XSnfFOG9GdjOUlI6ZChLOb8,6020
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/WHEEL ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.33.6)
3
+ Root-Is-Purelib: true
4
+ Tag: py2-none-any
5
+ Tag: py3-none-any
6
+
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ astor
.venv/lib/python3.11/site-packages/astor-0.8.1.dist-info/zip-safe ADDED
@@ -0,0 +1 @@
 
 
1
+
.venv/lib/python3.11/site-packages/huggingface_hub-0.28.1.dist-info/METADATA ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: huggingface-hub
3
+ Version: 0.28.1
4
+ Summary: Client library to download and publish models, datasets and other repos on the huggingface.co hub
5
+ Home-page: https://github.com/huggingface/huggingface_hub
6
+ Author: Hugging Face, Inc.
7
+ Author-email: julien@huggingface.co
8
+ License: Apache
9
+ Keywords: model-hub machine-learning models natural-language-processing deep-learning pytorch pretrained-models
10
+ Platform: UNKNOWN
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Education
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: OSI Approved :: Apache Software License
15
+ Classifier: Operating System :: OS Independent
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3 :: Only
18
+ Classifier: Programming Language :: Python :: 3.8
19
+ Classifier: Programming Language :: Python :: 3.9
20
+ Classifier: Programming Language :: Python :: 3.10
21
+ Classifier: Programming Language :: Python :: 3.11
22
+ Classifier: Programming Language :: Python :: 3.12
23
+ Classifier: Programming Language :: Python :: 3.13
24
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
25
+ Requires-Python: >=3.8.0
26
+ Description-Content-Type: text/markdown
27
+ License-File: LICENSE
28
+ Requires-Dist: filelock
29
+ Requires-Dist: fsspec>=2023.5.0
30
+ Requires-Dist: packaging>=20.9
31
+ Requires-Dist: pyyaml>=5.1
32
+ Requires-Dist: requests
33
+ Requires-Dist: tqdm>=4.42.1
34
+ Requires-Dist: typing-extensions>=3.7.4.3
35
+ Provides-Extra: all
36
+ Requires-Dist: InquirerPy==0.3.4; extra == "all"
37
+ Requires-Dist: aiohttp; extra == "all"
38
+ Requires-Dist: jedi; extra == "all"
39
+ Requires-Dist: Jinja2; extra == "all"
40
+ Requires-Dist: pytest<8.2.2,>=8.1.1; extra == "all"
41
+ Requires-Dist: pytest-cov; extra == "all"
42
+ Requires-Dist: pytest-env; extra == "all"
43
+ Requires-Dist: pytest-xdist; extra == "all"
44
+ Requires-Dist: pytest-vcr; extra == "all"
45
+ Requires-Dist: pytest-asyncio; extra == "all"
46
+ Requires-Dist: pytest-rerunfailures; extra == "all"
47
+ Requires-Dist: pytest-mock; extra == "all"
48
+ Requires-Dist: urllib3<2.0; extra == "all"
49
+ Requires-Dist: soundfile; extra == "all"
50
+ Requires-Dist: Pillow; extra == "all"
51
+ Requires-Dist: gradio>=4.0.0; extra == "all"
52
+ Requires-Dist: numpy; extra == "all"
53
+ Requires-Dist: fastapi; extra == "all"
54
+ Requires-Dist: ruff>=0.9.0; extra == "all"
55
+ Requires-Dist: mypy==1.5.1; extra == "all"
56
+ Requires-Dist: libcst==1.4.0; extra == "all"
57
+ Requires-Dist: typing-extensions>=4.8.0; extra == "all"
58
+ Requires-Dist: types-PyYAML; extra == "all"
59
+ Requires-Dist: types-requests; extra == "all"
60
+ Requires-Dist: types-simplejson; extra == "all"
61
+ Requires-Dist: types-toml; extra == "all"
62
+ Requires-Dist: types-tqdm; extra == "all"
63
+ Requires-Dist: types-urllib3; extra == "all"
64
+ Provides-Extra: cli
65
+ Requires-Dist: InquirerPy==0.3.4; extra == "cli"
66
+ Provides-Extra: dev
67
+ Requires-Dist: InquirerPy==0.3.4; extra == "dev"
68
+ Requires-Dist: aiohttp; extra == "dev"
69
+ Requires-Dist: jedi; extra == "dev"
70
+ Requires-Dist: Jinja2; extra == "dev"
71
+ Requires-Dist: pytest<8.2.2,>=8.1.1; extra == "dev"
72
+ Requires-Dist: pytest-cov; extra == "dev"
73
+ Requires-Dist: pytest-env; extra == "dev"
74
+ Requires-Dist: pytest-xdist; extra == "dev"
75
+ Requires-Dist: pytest-vcr; extra == "dev"
76
+ Requires-Dist: pytest-asyncio; extra == "dev"
77
+ Requires-Dist: pytest-rerunfailures; extra == "dev"
78
+ Requires-Dist: pytest-mock; extra == "dev"
79
+ Requires-Dist: urllib3<2.0; extra == "dev"
80
+ Requires-Dist: soundfile; extra == "dev"
81
+ Requires-Dist: Pillow; extra == "dev"
82
+ Requires-Dist: gradio>=4.0.0; extra == "dev"
83
+ Requires-Dist: numpy; extra == "dev"
84
+ Requires-Dist: fastapi; extra == "dev"
85
+ Requires-Dist: ruff>=0.9.0; extra == "dev"
86
+ Requires-Dist: mypy==1.5.1; extra == "dev"
87
+ Requires-Dist: libcst==1.4.0; extra == "dev"
88
+ Requires-Dist: typing-extensions>=4.8.0; extra == "dev"
89
+ Requires-Dist: types-PyYAML; extra == "dev"
90
+ Requires-Dist: types-requests; extra == "dev"
91
+ Requires-Dist: types-simplejson; extra == "dev"
92
+ Requires-Dist: types-toml; extra == "dev"
93
+ Requires-Dist: types-tqdm; extra == "dev"
94
+ Requires-Dist: types-urllib3; extra == "dev"
95
+ Provides-Extra: fastai
96
+ Requires-Dist: toml; extra == "fastai"
97
+ Requires-Dist: fastai>=2.4; extra == "fastai"
98
+ Requires-Dist: fastcore>=1.3.27; extra == "fastai"
99
+ Provides-Extra: hf_transfer
100
+ Requires-Dist: hf-transfer>=0.1.4; extra == "hf-transfer"
101
+ Provides-Extra: inference
102
+ Requires-Dist: aiohttp; extra == "inference"
103
+ Provides-Extra: quality
104
+ Requires-Dist: ruff>=0.9.0; extra == "quality"
105
+ Requires-Dist: mypy==1.5.1; extra == "quality"
106
+ Requires-Dist: libcst==1.4.0; extra == "quality"
107
+ Provides-Extra: tensorflow
108
+ Requires-Dist: tensorflow; extra == "tensorflow"
109
+ Requires-Dist: pydot; extra == "tensorflow"
110
+ Requires-Dist: graphviz; extra == "tensorflow"
111
+ Provides-Extra: tensorflow-testing
112
+ Requires-Dist: tensorflow; extra == "tensorflow-testing"
113
+ Requires-Dist: keras<3.0; extra == "tensorflow-testing"
114
+ Provides-Extra: testing
115
+ Requires-Dist: InquirerPy==0.3.4; extra == "testing"
116
+ Requires-Dist: aiohttp; extra == "testing"
117
+ Requires-Dist: jedi; extra == "testing"
118
+ Requires-Dist: Jinja2; extra == "testing"
119
+ Requires-Dist: pytest<8.2.2,>=8.1.1; extra == "testing"
120
+ Requires-Dist: pytest-cov; extra == "testing"
121
+ Requires-Dist: pytest-env; extra == "testing"
122
+ Requires-Dist: pytest-xdist; extra == "testing"
123
+ Requires-Dist: pytest-vcr; extra == "testing"
124
+ Requires-Dist: pytest-asyncio; extra == "testing"
125
+ Requires-Dist: pytest-rerunfailures; extra == "testing"
126
+ Requires-Dist: pytest-mock; extra == "testing"
127
+ Requires-Dist: urllib3<2.0; extra == "testing"
128
+ Requires-Dist: soundfile; extra == "testing"
129
+ Requires-Dist: Pillow; extra == "testing"
130
+ Requires-Dist: gradio>=4.0.0; extra == "testing"
131
+ Requires-Dist: numpy; extra == "testing"
132
+ Requires-Dist: fastapi; extra == "testing"
133
+ Provides-Extra: torch
134
+ Requires-Dist: torch; extra == "torch"
135
+ Requires-Dist: safetensors[torch]; extra == "torch"
136
+ Provides-Extra: typing
137
+ Requires-Dist: typing-extensions>=4.8.0; extra == "typing"
138
+ Requires-Dist: types-PyYAML; extra == "typing"
139
+ Requires-Dist: types-requests; extra == "typing"
140
+ Requires-Dist: types-simplejson; extra == "typing"
141
+ Requires-Dist: types-toml; extra == "typing"
142
+ Requires-Dist: types-tqdm; extra == "typing"
143
+ Requires-Dist: types-urllib3; extra == "typing"
144
+
145
+ <p align="center">
146
+ <picture>
147
+ <source media="(prefers-color-scheme: dark)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub-dark.svg">
148
+ <source media="(prefers-color-scheme: light)" srcset="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub.svg">
149
+ <img alt="huggingface_hub library logo" src="https://huggingface.co/datasets/huggingface/documentation-images/raw/main/huggingface_hub.svg" width="352" height="59" style="max-width: 100%;">
150
+ </picture>
151
+ <br/>
152
+ <br/>
153
+ </p>
154
+
155
+ <p align="center">
156
+ <i>The official Python client for the Huggingface Hub.</i>
157
+ </p>
158
+
159
+ <p align="center">
160
+ <a href="https://huggingface.co/docs/huggingface_hub/en/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/huggingface_hub/index.svg?down_color=red&down_message=offline&up_message=online&label=doc"></a>
161
+ <a href="https://github.com/huggingface/huggingface_hub/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/huggingface_hub.svg"></a>
162
+ <a href="https://github.com/huggingface/huggingface_hub"><img alt="PyPi version" src="https://img.shields.io/pypi/pyversions/huggingface_hub.svg"></a>
163
+ <a href="https://pypi.org/project/huggingface-hub"><img alt="PyPI - Downloads" src="https://img.shields.io/pypi/dm/huggingface_hub"></a>
164
+ <a href="https://codecov.io/gh/huggingface/huggingface_hub"><img alt="Code coverage" src="https://codecov.io/gh/huggingface/huggingface_hub/branch/main/graph/badge.svg?token=RXP95LE2XL"></a>
165
+ </p>
166
+
167
+ <h4 align="center">
168
+ <p>
169
+ <b>English</b> |
170
+ <a href="https://github.com/huggingface/huggingface_hub/blob/main/i18n/README_de.md">Deutsch</a> |
171
+ <a href="https://github.com/huggingface/huggingface_hub/blob/main/i18n/README_hi.md">हिंदी</a> |
172
+ <a href="https://github.com/huggingface/huggingface_hub/blob/main/i18n/README_ko.md">한국어</a> |
173
+ <a href="https://github.com/huggingface/huggingface_hub/blob/main/i18n/README_cn.md">中文(简体)</a>
174
+ <p>
175
+ </h4>
176
+
177
+ ---
178
+
179
+ **Documentation**: <a href="https://hf.co/docs/huggingface_hub" target="_blank">https://hf.co/docs/huggingface_hub</a>
180
+
181
+ **Source Code**: <a href="https://github.com/huggingface/huggingface_hub" target="_blank">https://github.com/huggingface/huggingface_hub</a>
182
+
183
+ ---
184
+
185
+ ## Welcome to the huggingface_hub library
186
+
187
+ The `huggingface_hub` library allows you to interact with the [Hugging Face Hub](https://huggingface.co/), a platform democratizing open-source Machine Learning for creators and collaborators. Discover pre-trained models and datasets for your projects or play with the thousands of machine learning apps hosted on the Hub. You can also create and share your own models, datasets and demos with the community. The `huggingface_hub` library provides a simple way to do all these things with Python.
188
+
189
+ ## Key features
190
+
191
+ - [Download files](https://huggingface.co/docs/huggingface_hub/en/guides/download) from the Hub.
192
+ - [Upload files](https://huggingface.co/docs/huggingface_hub/en/guides/upload) to the Hub.
193
+ - [Manage your repositories](https://huggingface.co/docs/huggingface_hub/en/guides/repository).
194
+ - [Run Inference](https://huggingface.co/docs/huggingface_hub/en/guides/inference) on deployed models.
195
+ - [Search](https://huggingface.co/docs/huggingface_hub/en/guides/search) for models, datasets and Spaces.
196
+ - [Share Model Cards](https://huggingface.co/docs/huggingface_hub/en/guides/model-cards) to document your models.
197
+ - [Engage with the community](https://huggingface.co/docs/huggingface_hub/en/guides/community) through PRs and comments.
198
+
199
+ ## Installation
200
+
201
+ Install the `huggingface_hub` package with [pip](https://pypi.org/project/huggingface-hub/):
202
+
203
+ ```bash
204
+ pip install huggingface_hub
205
+ ```
206
+
207
+ If you prefer, you can also install it with [conda](https://huggingface.co/docs/huggingface_hub/en/installation#install-with-conda).
208
+
209
+ In order to keep the package minimal by default, `huggingface_hub` comes with optional dependencies useful for some use cases. For example, if you want have a complete experience for Inference, run:
210
+
211
+ ```bash
212
+ pip install huggingface_hub[inference]
213
+ ```
214
+
215
+ To learn more installation and optional dependencies, check out the [installation guide](https://huggingface.co/docs/huggingface_hub/en/installation).
216
+
217
+ ## Quick start
218
+
219
+ ### Download files
220
+
221
+ Download a single file
222
+
223
+ ```py
224
+ from huggingface_hub import hf_hub_download
225
+
226
+ hf_hub_download(repo_id="tiiuae/falcon-7b-instruct", filename="config.json")
227
+ ```
228
+
229
+ Or an entire repository
230
+
231
+ ```py
232
+ from huggingface_hub import snapshot_download
233
+
234
+ snapshot_download("stabilityai/stable-diffusion-2-1")
235
+ ```
236
+
237
+ Files will be downloaded in a local cache folder. More details in [this guide](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache).
238
+
239
+ ### Login
240
+
241
+ The Hugging Face Hub uses tokens to authenticate applications (see [docs](https://huggingface.co/docs/hub/security-tokens)). To log in your machine, run the following CLI:
242
+
243
+ ```bash
244
+ huggingface-cli login
245
+ # or using an environment variable
246
+ huggingface-cli login --token $HUGGINGFACE_TOKEN
247
+ ```
248
+
249
+ ### Create a repository
250
+
251
+ ```py
252
+ from huggingface_hub import create_repo
253
+
254
+ create_repo(repo_id="super-cool-model")
255
+ ```
256
+
257
+ ### Upload files
258
+
259
+ Upload a single file
260
+
261
+ ```py
262
+ from huggingface_hub import upload_file
263
+
264
+ upload_file(
265
+ path_or_fileobj="/home/lysandre/dummy-test/README.md",
266
+ path_in_repo="README.md",
267
+ repo_id="lysandre/test-model",
268
+ )
269
+ ```
270
+
271
+ Or an entire folder
272
+
273
+ ```py
274
+ from huggingface_hub import upload_folder
275
+
276
+ upload_folder(
277
+ folder_path="/path/to/local/space",
278
+ repo_id="username/my-cool-space",
279
+ repo_type="space",
280
+ )
281
+ ```
282
+
283
+ For details in the [upload guide](https://huggingface.co/docs/huggingface_hub/en/guides/upload).
284
+
285
+ ## Integrating to the Hub.
286
+
287
+ We're partnering with cool open source ML libraries to provide free model hosting and versioning. You can find the existing integrations [here](https://huggingface.co/docs/hub/libraries).
288
+
289
+ The advantages are:
290
+
291
+ - Free model or dataset hosting for libraries and their users.
292
+ - Built-in file versioning, even with very large files, thanks to a git-based approach.
293
+ - Serverless inference API for all models publicly available.
294
+ - In-browser widgets to play with the uploaded models.
295
+ - Anyone can upload a new model for your library, they just need to add the corresponding tag for the model to be discoverable.
296
+ - Fast downloads! We use Cloudfront (a CDN) to geo-replicate downloads so they're blazing fast from anywhere on the globe.
297
+ - Usage stats and more features to come.
298
+
299
+ If you would like to integrate your library, feel free to open an issue to begin the discussion. We wrote a [step-by-step guide](https://huggingface.co/docs/hub/adding-a-library) with ❤️ showing how to do this integration.
300
+
301
+ ## Contributions (feature requests, bugs, etc.) are super welcome 💙💚💛💜🧡❤️
302
+
303
+ Everyone is welcome to contribute, and we value everybody's contribution. Code is not the only way to help the community.
304
+ Answering questions, helping others, reaching out and improving the documentations are immensely valuable to the community.
305
+ We wrote a [contribution guide](https://github.com/huggingface/huggingface_hub/blob/main/CONTRIBUTING.md) to summarize
306
+ how to get started to contribute to this repository.
307
+
308
+
.venv/lib/python3.11/site-packages/huggingface_hub-0.28.1.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ huggingface_hub
.venv/lib/python3.11/site-packages/nvidia_cuda_cupti_cu12-12.4.127.dist-info/METADATA ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: nvidia-cuda-cupti-cu12
3
+ Version: 12.4.127
4
+ Summary: CUDA profiling tools runtime libs.
5
+ Home-page: https://developer.nvidia.com/cuda-zone
6
+ Author: Nvidia CUDA Installer Team
7
+ Author-email: cuda_installer@nvidia.com
8
+ License: NVIDIA Proprietary Software
9
+ Keywords: cuda,nvidia,runtime,machine learning,deep learning
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Education
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: License :: Other/Proprietary License
15
+ Classifier: Natural Language :: English
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.5
18
+ Classifier: Programming Language :: Python :: 3.6
19
+ Classifier: Programming Language :: Python :: 3.7
20
+ Classifier: Programming Language :: Python :: 3.8
21
+ Classifier: Programming Language :: Python :: 3.9
22
+ Classifier: Programming Language :: Python :: 3.10
23
+ Classifier: Programming Language :: Python :: 3.11
24
+ Classifier: Programming Language :: Python :: 3 :: Only
25
+ Classifier: Topic :: Scientific/Engineering
26
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
27
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
28
+ Classifier: Topic :: Software Development
29
+ Classifier: Topic :: Software Development :: Libraries
30
+ Classifier: Operating System :: Microsoft :: Windows
31
+ Classifier: Operating System :: POSIX :: Linux
32
+ Requires-Python: >=3
33
+ License-File: License.txt
34
+
35
+ Provides libraries to enable third party tools using GPU profiling APIs.
.venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/LICENSE ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2015 Radim Řehůřek
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
.venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: smart-open
3
+ Version: 7.1.0
4
+ Summary: Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)
5
+ Home-page: https://github.com/piskvorky/smart_open
6
+ Author: Radim Rehurek
7
+ Author-email: me@radimrehurek.com
8
+ Maintainer: Radim Rehurek
9
+ Maintainer-email: me@radimrehurek.com
10
+ License: MIT
11
+ Download-URL: http://pypi.python.org/pypi/smart_open
12
+ Keywords: file streaming,s3,hdfs,gcs,azure blob storage
13
+ Platform: any
14
+ Classifier: Development Status :: 5 - Production/Stable
15
+ Classifier: Environment :: Console
16
+ Classifier: Intended Audience :: Developers
17
+ Classifier: License :: OSI Approved :: MIT License
18
+ Classifier: Operating System :: OS Independent
19
+ Classifier: Programming Language :: Python :: 3.7
20
+ Classifier: Programming Language :: Python :: 3.8
21
+ Classifier: Programming Language :: Python :: 3.9
22
+ Classifier: Programming Language :: Python :: 3.10
23
+ Classifier: Programming Language :: Python :: 3.11
24
+ Classifier: Programming Language :: Python :: 3.12
25
+ Classifier: Programming Language :: Python :: 3.13
26
+ Classifier: Topic :: System :: Distributed Computing
27
+ Classifier: Topic :: Database :: Front-Ends
28
+ Requires-Python: >=3.7,<4.0
29
+ Requires-Dist: wrapt
30
+ Provides-Extra: all
31
+ Requires-Dist: boto3; extra == "all"
32
+ Requires-Dist: google-cloud-storage>=2.6.0; extra == "all"
33
+ Requires-Dist: azure-storage-blob; extra == "all"
34
+ Requires-Dist: azure-common; extra == "all"
35
+ Requires-Dist: azure-core; extra == "all"
36
+ Requires-Dist: requests; extra == "all"
37
+ Requires-Dist: paramiko; extra == "all"
38
+ Requires-Dist: zstandard; extra == "all"
39
+ Provides-Extra: azure
40
+ Requires-Dist: azure-storage-blob; extra == "azure"
41
+ Requires-Dist: azure-common; extra == "azure"
42
+ Requires-Dist: azure-core; extra == "azure"
43
+ Provides-Extra: gcs
44
+ Requires-Dist: google-cloud-storage>=2.6.0; extra == "gcs"
45
+ Provides-Extra: http
46
+ Requires-Dist: requests; extra == "http"
47
+ Provides-Extra: s3
48
+ Requires-Dist: boto3; extra == "s3"
49
+ Provides-Extra: ssh
50
+ Requires-Dist: paramiko; extra == "ssh"
51
+ Provides-Extra: test
52
+ Requires-Dist: boto3; extra == "test"
53
+ Requires-Dist: google-cloud-storage>=2.6.0; extra == "test"
54
+ Requires-Dist: azure-storage-blob; extra == "test"
55
+ Requires-Dist: azure-common; extra == "test"
56
+ Requires-Dist: azure-core; extra == "test"
57
+ Requires-Dist: requests; extra == "test"
58
+ Requires-Dist: paramiko; extra == "test"
59
+ Requires-Dist: zstandard; extra == "test"
60
+ Requires-Dist: moto[server]; extra == "test"
61
+ Requires-Dist: responses; extra == "test"
62
+ Requires-Dist: pytest; extra == "test"
63
+ Requires-Dist: pytest-rerunfailures; extra == "test"
64
+ Requires-Dist: pytest-benchmark; extra == "test"
65
+ Requires-Dist: awscli; extra == "test"
66
+ Requires-Dist: pyopenssl; extra == "test"
67
+ Requires-Dist: numpy; extra == "test"
68
+ Provides-Extra: webhdfs
69
+ Requires-Dist: requests; extra == "webhdfs"
70
+ Provides-Extra: zst
71
+ Requires-Dist: zstandard; extra == "zst"
72
+
73
+ ======================================================
74
+ smart_open — utils for streaming large files in Python
75
+ ======================================================
76
+
77
+
78
+ |License|_ |GHA|_ |Coveralls|_ |Downloads|_
79
+
80
+ .. |License| image:: https://img.shields.io/pypi/l/smart_open.svg
81
+ .. |GHA| image:: https://github.com/RaRe-Technologies/smart_open/workflows/Test/badge.svg
82
+ .. |Coveralls| image:: https://coveralls.io/repos/github/RaRe-Technologies/smart_open/badge.svg?branch=develop
83
+ .. |Downloads| image:: https://pepy.tech/badge/smart-open/month
84
+ .. _License: https://github.com/RaRe-Technologies/smart_open/blob/master/LICENSE
85
+ .. _GHA: https://github.com/RaRe-Technologies/smart_open/actions?query=workflow%3ATest
86
+ .. _Coveralls: https://coveralls.io/github/RaRe-Technologies/smart_open?branch=HEAD
87
+ .. _Downloads: https://pypi.org/project/smart-open/
88
+
89
+
90
+ What?
91
+ =====
92
+
93
+ ``smart_open`` is a Python 3 library for **efficient streaming of very large files** from/to storages such as S3, GCS, Azure Blob Storage, HDFS, WebHDFS, HTTP, HTTPS, SFTP, or local filesystem. It supports transparent, on-the-fly (de-)compression for a variety of different formats.
94
+
95
+ ``smart_open`` is a drop-in replacement for Python's built-in ``open()``: it can do anything ``open`` can (100% compatible, falls back to native ``open`` wherever possible), plus lots of nifty extra stuff on top.
96
+
97
+ **Python 2.7 is no longer supported. If you need Python 2.7, please use** `smart_open 1.10.1 <https://github.com/RaRe-Technologies/smart_open/releases/tag/1.10.0>`_, **the last version to support Python 2.**
98
+
99
+ Why?
100
+ ====
101
+
102
+ Working with large remote files, for example using Amazon's `boto3 <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>`_ Python library, is a pain.
103
+ ``boto3``'s ``Object.upload_fileobj()`` and ``Object.download_fileobj()`` methods require gotcha-prone boilerplate to use successfully, such as constructing file-like object wrappers.
104
+ ``smart_open`` shields you from that. It builds on boto3 and other remote storage libraries, but offers a **clean unified Pythonic API**. The result is less code for you to write and fewer bugs to make.
105
+
106
+
107
+ How?
108
+ =====
109
+
110
+ ``smart_open`` is well-tested, well-documented, and has a simple Pythonic API:
111
+
112
+
113
+ .. _doctools_before_examples:
114
+
115
+ .. code-block:: python
116
+
117
+ >>> from smart_open import open
118
+ >>>
119
+ >>> # stream lines from an S3 object
120
+ >>> for line in open('s3://commoncrawl/robots.txt'):
121
+ ... print(repr(line))
122
+ ... break
123
+ 'User-Agent: *\n'
124
+
125
+ >>> # stream from/to compressed files, with transparent (de)compression:
126
+ >>> for line in open('smart_open/tests/test_data/1984.txt.gz', encoding='utf-8'):
127
+ ... print(repr(line))
128
+ 'It was a bright cold day in April, and the clocks were striking thirteen.\n'
129
+ 'Winston Smith, his chin nuzzled into his breast in an effort to escape the vile\n'
130
+ 'wind, slipped quickly through the glass doors of Victory Mansions, though not\n'
131
+ 'quickly enough to prevent a swirl of gritty dust from entering along with him.\n'
132
+
133
+ >>> # can use context managers too:
134
+ >>> with open('smart_open/tests/test_data/1984.txt.gz') as fin:
135
+ ... with open('smart_open/tests/test_data/1984.txt.bz2', 'w') as fout:
136
+ ... for line in fin:
137
+ ... fout.write(line)
138
+ 74
139
+ 80
140
+ 78
141
+ 79
142
+
143
+ >>> # can use any IOBase operations, like seek
144
+ >>> with open('s3://commoncrawl/robots.txt', 'rb') as fin:
145
+ ... for line in fin:
146
+ ... print(repr(line.decode('utf-8')))
147
+ ... break
148
+ ... offset = fin.seek(0) # seek to the beginning
149
+ ... print(fin.read(4))
150
+ 'User-Agent: *\n'
151
+ b'User'
152
+
153
+ >>> # stream from HTTP
154
+ >>> for line in open('http://example.com/index.html'):
155
+ ... print(repr(line))
156
+ ... break
157
+ '<!doctype html>\n'
158
+
159
+ .. _doctools_after_examples:
160
+
161
+ Other examples of URLs that ``smart_open`` accepts::
162
+
163
+ s3://my_bucket/my_key
164
+ s3://my_key:my_secret@my_bucket/my_key
165
+ s3://my_key:my_secret@my_server:my_port@my_bucket/my_key
166
+ gs://my_bucket/my_blob
167
+ azure://my_bucket/my_blob
168
+ hdfs:///path/file
169
+ hdfs://path/file
170
+ webhdfs://host:port/path/file
171
+ ./local/path/file
172
+ ~/local/path/file
173
+ local/path/file
174
+ ./local/path/file.gz
175
+ file:///home/user/file
176
+ file:///home/user/file.bz2
177
+ [ssh|scp|sftp]://username@host//path/file
178
+ [ssh|scp|sftp]://username@host/path/file
179
+ [ssh|scp|sftp]://username:password@host/path/file
180
+
181
+
182
+ Documentation
183
+ =============
184
+
185
+ Installation
186
+ ------------
187
+
188
+ ``smart_open`` supports a wide range of storage solutions, including AWS S3, Google Cloud and Azure.
189
+ Each individual solution has its own dependencies.
190
+ By default, ``smart_open`` does not install any dependencies, in order to keep the installation size small.
191
+ You can install these dependencies explicitly using::
192
+
193
+ pip install smart_open[azure] # Install Azure deps
194
+ pip install smart_open[gcs] # Install GCS deps
195
+ pip install smart_open[s3] # Install S3 deps
196
+
197
+ Or, if you don't mind installing a large number of third party libraries, you can install all dependencies using::
198
+
199
+ pip install smart_open[all]
200
+
201
+ Be warned that this option increases the installation size significantly, e.g. over 100MB.
202
+
203
+ If you're upgrading from ``smart_open`` versions 2.x and below, please check out the `Migration Guide <MIGRATING_FROM_OLDER_VERSIONS.rst>`_.
204
+
205
+ Built-in help
206
+ -------------
207
+
208
+ For detailed API info, see the online help:
209
+
210
+ .. code-block:: python
211
+
212
+ help('smart_open')
213
+
214
+ or click `here <https://github.com/RaRe-Technologies/smart_open/blob/master/help.txt>`__ to view the help in your browser.
215
+
216
+ More examples
217
+ -------------
218
+
219
+ For the sake of simplicity, the examples below assume you have all the dependencies installed, i.e. you have done::
220
+
221
+ pip install smart_open[all]
222
+
223
+ .. code-block:: python
224
+
225
+ >>> import os, boto3
226
+ >>> from smart_open import open
227
+ >>>
228
+ >>> # stream content *into* S3 (write mode) using a custom session
229
+ >>> session = boto3.Session(
230
+ ... aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
231
+ ... aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
232
+ ... )
233
+ >>> url = 's3://smart-open-py37-benchmark-results/test.txt'
234
+ >>> with open(url, 'wb', transport_params={'client': session.client('s3')}) as fout:
235
+ ... bytes_written = fout.write(b'hello world!')
236
+ ... print(bytes_written)
237
+ 12
238
+
239
+ .. code-block:: python
240
+
241
+ # stream from HDFS
242
+ for line in open('hdfs://user/hadoop/my_file.txt', encoding='utf8'):
243
+ print(line)
244
+
245
+ # stream from WebHDFS
246
+ for line in open('webhdfs://host:port/user/hadoop/my_file.txt'):
247
+ print(line)
248
+
249
+ # stream content *into* HDFS (write mode):
250
+ with open('hdfs://host:port/user/hadoop/my_file.txt', 'wb') as fout:
251
+ fout.write(b'hello world')
252
+
253
+ # stream content *into* WebHDFS (write mode):
254
+ with open('webhdfs://host:port/user/hadoop/my_file.txt', 'wb') as fout:
255
+ fout.write(b'hello world')
256
+
257
+ # stream from a completely custom s3 server, like s3proxy:
258
+ for line in open('s3u://user:secret@host:port@mybucket/mykey.txt'):
259
+ print(line)
260
+
261
+ # Stream to Digital Ocean Spaces bucket providing credentials from boto3 profile
262
+ session = boto3.Session(profile_name='digitalocean')
263
+ client = session.client('s3', endpoint_url='https://ams3.digitaloceanspaces.com')
264
+ transport_params = {'client': client}
265
+ with open('s3://bucket/key.txt', 'wb', transport_params=transport_params) as fout:
266
+ fout.write(b'here we stand')
267
+
268
+ # stream from GCS
269
+ for line in open('gs://my_bucket/my_file.txt'):
270
+ print(line)
271
+
272
+ # stream content *into* GCS (write mode):
273
+ with open('gs://my_bucket/my_file.txt', 'wb') as fout:
274
+ fout.write(b'hello world')
275
+
276
+ # stream from Azure Blob Storage
277
+ connect_str = os.environ['AZURE_STORAGE_CONNECTION_STRING']
278
+ transport_params = {
279
+ 'client': azure.storage.blob.BlobServiceClient.from_connection_string(connect_str),
280
+ }
281
+ for line in open('azure://mycontainer/myfile.txt', transport_params=transport_params):
282
+ print(line)
283
+
284
+ # stream content *into* Azure Blob Storage (write mode):
285
+ connect_str = os.environ['AZURE_STORAGE_CONNECTION_STRING']
286
+ transport_params = {
287
+ 'client': azure.storage.blob.BlobServiceClient.from_connection_string(connect_str),
288
+ }
289
+ with open('azure://mycontainer/my_file.txt', 'wb', transport_params=transport_params) as fout:
290
+ fout.write(b'hello world')
291
+
292
+ Compression Handling
293
+ --------------------
294
+
295
+ The top-level `compression` parameter controls compression/decompression behavior when reading and writing.
296
+ The supported values for this parameter are:
297
+
298
+ - ``infer_from_extension`` (default behavior)
299
+ - ``disable``
300
+ - ``.gz``
301
+ - ``.bz2``
302
+ - ``.zst``
303
+
304
+ By default, ``smart_open`` determines the compression algorithm to use based on the file extension.
305
+
306
+ .. code-block:: python
307
+
308
+ >>> from smart_open import open, register_compressor
309
+ >>> with open('smart_open/tests/test_data/1984.txt.gz') as fin:
310
+ ... print(fin.read(32))
311
+ It was a bright cold day in Apri
312
+
313
+ You can override this behavior to either disable compression, or explicitly specify the algorithm to use.
314
+ To disable compression:
315
+
316
+ .. code-block:: python
317
+
318
+ >>> from smart_open import open, register_compressor
319
+ >>> with open('smart_open/tests/test_data/1984.txt.gz', 'rb', compression='disable') as fin:
320
+ ... print(fin.read(32))
321
+ b'\x1f\x8b\x08\x08\x85F\x94\\\x00\x031984.txt\x005\x8f=r\xc3@\x08\x85{\x9d\xe2\x1d@'
322
+
323
+
324
+ To specify the algorithm explicitly (e.g. for non-standard file extensions):
325
+
326
+ .. code-block:: python
327
+
328
+ >>> from smart_open import open, register_compressor
329
+ >>> with open('smart_open/tests/test_data/1984.txt.gzip', compression='.gz') as fin:
330
+ ... print(fin.read(32))
331
+ It was a bright cold day in Apri
332
+
333
+ You can also easily add support for other file extensions and compression formats.
334
+ For example, to open xz-compressed files:
335
+
336
+ .. code-block:: python
337
+
338
+ >>> import lzma, os
339
+ >>> from smart_open import open, register_compressor
340
+
341
+ >>> def _handle_xz(file_obj, mode):
342
+ ... return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
343
+
344
+ >>> register_compressor('.xz', _handle_xz)
345
+
346
+ >>> with open('smart_open/tests/test_data/1984.txt.xz') as fin:
347
+ ... print(fin.read(32))
348
+ It was a bright cold day in Apri
349
+
350
+ ``lzma`` is in the standard library in Python 3.3 and greater.
351
+ For 2.7, use `backports.lzma`_.
352
+
353
+ .. _backports.lzma: https://pypi.org/project/backports.lzma/
354
+
355
+ Transport-specific Options
356
+ --------------------------
357
+
358
+ ``smart_open`` supports a wide range of transport options out of the box, including:
359
+
360
+ - S3
361
+ - HTTP, HTTPS (read-only)
362
+ - SSH, SCP and SFTP
363
+ - WebHDFS
364
+ - GCS
365
+ - Azure Blob Storage
366
+
367
+ Each option involves setting up its own set of parameters.
368
+ For example, for accessing S3, you often need to set up authentication, like API keys or a profile name.
369
+ ``smart_open``'s ``open`` function accepts a keyword argument ``transport_params`` which accepts additional parameters for the transport layer.
370
+ Here are some examples of using this parameter:
371
+
372
+ .. code-block:: python
373
+
374
+ >>> import boto3
375
+ >>> fin = open('s3://commoncrawl/robots.txt', transport_params=dict(client=boto3.client('s3')))
376
+ >>> fin = open('s3://commoncrawl/robots.txt', transport_params=dict(buffer_size=1024))
377
+
378
+ For the full list of keyword arguments supported by each transport option, see the documentation:
379
+
380
+ .. code-block:: python
381
+
382
+ help('smart_open.open')
383
+
384
+ S3 Credentials
385
+ --------------
386
+
387
+ ``smart_open`` uses the ``boto3`` library to talk to S3.
388
+ ``boto3`` has several `mechanisms <https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html>`__ for determining the credentials to use.
389
+ By default, ``smart_open`` will defer to ``boto3`` and let the latter take care of the credentials.
390
+ There are several ways to override this behavior.
391
+
392
+ The first is to pass a ``boto3.Client`` object as a transport parameter to the ``open`` function.
393
+ You can customize the credentials when constructing the session for the client.
394
+ ``smart_open`` will then use the session when talking to S3.
395
+
396
+ .. code-block:: python
397
+
398
+ session = boto3.Session(
399
+ aws_access_key_id=ACCESS_KEY,
400
+ aws_secret_access_key=SECRET_KEY,
401
+ aws_session_token=SESSION_TOKEN,
402
+ )
403
+ client = session.client('s3', endpoint_url=..., config=...)
404
+ fin = open('s3://bucket/key', transport_params={'client': client})
405
+
406
+ Your second option is to specify the credentials within the S3 URL itself:
407
+
408
+ .. code-block:: python
409
+
410
+ fin = open('s3://aws_access_key_id:aws_secret_access_key@bucket/key', ...)
411
+
412
+ *Important*: The two methods above are **mutually exclusive**. If you pass an AWS client *and* the URL contains credentials, ``smart_open`` will ignore the latter.
413
+
414
+ *Important*: ``smart_open`` ignores configuration files from the older ``boto`` library.
415
+ Port your old ``boto`` settings to ``boto3`` in order to use them with ``smart_open``.
416
+
417
+ S3 Advanced Usage
418
+ -----------------
419
+
420
+ Additional keyword arguments can be propagated to the boto3 methods that are used by ``smart_open`` under the hood using the ``client_kwargs`` transport parameter.
421
+
422
+ For instance, to upload a blob with Metadata, ACL, StorageClass, these keyword arguments can be passed to ``create_multipart_upload`` (`docs <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.create_multipart_upload>`__).
423
+
424
+ .. code-block:: python
425
+
426
+ kwargs = {'Metadata': {'version': 2}, 'ACL': 'authenticated-read', 'StorageClass': 'STANDARD_IA'}
427
+ fout = open('s3://bucket/key', 'wb', transport_params={'client_kwargs': {'S3.Client.create_multipart_upload': kwargs}})
428
+
429
+ Iterating Over an S3 Bucket's Contents
430
+ --------------------------------------
431
+
432
+ Since going over all (or select) keys in an S3 bucket is a very common operation, there's also an extra function ``smart_open.s3.iter_bucket()`` that does this efficiently, **processing the bucket keys in parallel** (using multiprocessing):
433
+
434
+ .. code-block:: python
435
+
436
+ >>> from smart_open import s3
437
+ >>> # we use workers=1 for reproducibility; you should use as many workers as you have cores
438
+ >>> bucket = 'silo-open-data'
439
+ >>> prefix = 'Official/annual/monthly_rain/'
440
+ >>> for key, content in s3.iter_bucket(bucket, prefix=prefix, accept_key=lambda key: '/201' in key, workers=1, key_limit=3):
441
+ ... print(key, round(len(content) / 2**20))
442
+ Official/annual/monthly_rain/2010.monthly_rain.nc 13
443
+ Official/annual/monthly_rain/2011.monthly_rain.nc 13
444
+ Official/annual/monthly_rain/2012.monthly_rain.nc 13
445
+
446
+ GCS Credentials
447
+ ---------------
448
+ ``smart_open`` uses the ``google-cloud-storage`` library to talk to GCS.
449
+ ``google-cloud-storage`` uses the ``google-cloud`` package under the hood to handle authentication.
450
+ There are several `options <https://googleapis.dev/python/google-api-core/latest/auth.html>`__ to provide
451
+ credentials.
452
+ By default, ``smart_open`` will defer to ``google-cloud-storage`` and let it take care of the credentials.
453
+
454
+ To override this behavior, pass a ``google.cloud.storage.Client`` object as a transport parameter to the ``open`` function.
455
+ You can `customize the credentials <https://googleapis.dev/python/storage/latest/client.html>`__
456
+ when constructing the client. ``smart_open`` will then use the client when talking to GCS. To follow along with
457
+ the example below, `refer to Google's guide <https://cloud.google.com/storage/docs/reference/libraries#setting_up_authentication>`__
458
+ to setting up GCS authentication with a service account.
459
+
460
+ .. code-block:: python
461
+
462
+ import os
463
+ from google.cloud.storage import Client
464
+ service_account_path = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
465
+ client = Client.from_service_account_json(service_account_path)
466
+ fin = open('gs://gcp-public-data-landsat/index.csv.gz', transport_params=dict(client=client))
467
+
468
+ If you need more credential options, you can create an explicit ``google.auth.credentials.Credentials`` object
469
+ and pass it to the Client. To create an API token for use in the example below, refer to the
470
+ `GCS authentication guide <https://cloud.google.com/storage/docs/authentication#apiauth>`__.
471
+
472
+ .. code-block:: python
473
+
474
+ import os
475
+ from google.auth.credentials import Credentials
476
+ from google.cloud.storage import Client
477
+ token = os.environ['GOOGLE_API_TOKEN']
478
+ credentials = Credentials(token=token)
479
+ client = Client(credentials=credentials)
480
+ fin = open('gs://gcp-public-data-landsat/index.csv.gz', transport_params={'client': client})
481
+
482
+ GCS Advanced Usage
483
+ ------------------
484
+
485
+ Additional keyword arguments can be propagated to the GCS open method (`docs <https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.blob.Blob#google_cloud_storage_blob_Blob_open>`__), which is used by ``smart_open`` under the hood, using the ``blob_open_kwargs`` transport parameter.
486
+
487
+ Additionally keyword arguments can be propagated to the GCS ``get_blob`` method (`docs <https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.bucket.Bucket#google_cloud_storage_bucket_Bucket_get_blob>`__) when in a read-mode, using the ``get_blob_kwargs`` transport parameter.
488
+
489
+ Additional blob properties (`docs <https://cloud.google.com/python/docs/reference/storage/latest/google.cloud.storage.blob.Blob#properties>`__) can be set before an upload, as long as they are not read-only, using the ``blob_properties`` transport parameter.
490
+
491
+ .. code-block:: python
492
+
493
+ open_kwargs = {'predefined_acl': 'authenticated-read'}
494
+ properties = {'metadata': {'version': 2}, 'storage_class': 'COLDLINE'}
495
+ fout = open('gs://bucket/key', 'wb', transport_params={'blob_open_kwargs': open_kwargs, 'blob_properties': properties})
496
+
497
+ Azure Credentials
498
+ -----------------
499
+
500
+ ``smart_open`` uses the ``azure-storage-blob`` library to talk to Azure Blob Storage.
501
+ By default, ``smart_open`` will defer to ``azure-storage-blob`` and let it take care of the credentials.
502
+
503
+ Azure Blob Storage does not have any way of inferring credentials; therefore, passing an ``azure.storage.blob.BlobServiceClient``
504
+ object as a transport parameter to the ``open`` function is required.
505
+ You can `customize the credentials <https://docs.microsoft.com/en-us/azure/storage/common/storage-samples-python#authentication>`__
506
+ when constructing the client. ``smart_open`` will then use the client when talking to Azure Blob Storage. To follow along with
507
+ the example below, `refer to Azure's guide <https://docs.microsoft.com/en-us/azure/storage/blobs/storage-quickstart-blobs-python#copy-your-credentials-from-the-azure-portal>`__
508
+ to setting up authentication.
509
+
510
+ .. code-block:: python
511
+
512
+ import os
513
+ from azure.storage.blob import BlobServiceClient
514
+ azure_storage_connection_string = os.environ['AZURE_STORAGE_CONNECTION_STRING']
515
+ client = BlobServiceClient.from_connection_string(azure_storage_connection_string)
516
+ fin = open('azure://my_container/my_blob.txt', transport_params={'client': client})
517
+
518
+ If you need more credential options, refer to the
519
+ `Azure Storage authentication guide <https://docs.microsoft.com/en-us/azure/storage/common/storage-samples-python#authentication>`__.
520
+
521
+ Azure Advanced Usage
522
+ --------------------
523
+
524
+ Additional keyword arguments can be propagated to the ``commit_block_list`` method (`docs <https://azuresdkdocs.blob.core.windows.net/$web/python/azure-storage-blob/12.14.1/azure.storage.blob.html#azure.storage.blob.BlobClient.commit_block_list>`__), which is used by ``smart_open`` under the hood for uploads, using the ``blob_kwargs`` transport parameter.
525
+
526
+ .. code-block:: python
527
+
528
+ kwargs = {'metadata': {'version': 2}}
529
+ fout = open('azure://container/key', 'wb', transport_params={'blob_kwargs': kwargs})
530
+
531
+ Drop-in replacement of ``pathlib.Path.open``
532
+ --------------------------------------------
533
+
534
+ ``smart_open.open`` can also be used with ``Path`` objects.
535
+ The built-in `Path.open()` is not able to read text from compressed files, so use ``patch_pathlib`` to replace it with `smart_open.open()` instead.
536
+ This can be helpful when e.g. working with compressed files.
537
+
538
+ .. code-block:: python
539
+
540
+ >>> from pathlib import Path
541
+ >>> from smart_open.smart_open_lib import patch_pathlib
542
+ >>>
543
+ >>> _ = patch_pathlib() # replace `Path.open` with `smart_open.open`
544
+ >>>
545
+ >>> path = Path("smart_open/tests/test_data/crime-and-punishment.txt.gz")
546
+ >>>
547
+ >>> with path.open("r") as infile:
548
+ ... print(infile.readline()[:41])
549
+ В начале июля, в чрезвычайно жаркое время
550
+
551
+ How do I ...?
552
+ =============
553
+
554
+ See `this document <howto.md>`__.
555
+
556
+ Extending ``smart_open``
557
+ ========================
558
+
559
+ See `this document <extending.md>`__.
560
+
561
+ Testing ``smart_open``
562
+ ======================
563
+
564
+ ``smart_open`` comes with a comprehensive suite of unit tests.
565
+ Before you can run the test suite, install the test dependencies::
566
+
567
+ pip install -e .[test]
568
+
569
+ Now, you can run the unit tests::
570
+
571
+ pytest smart_open
572
+
573
+ The tests are also run automatically with `Travis CI <https://travis-ci.org/RaRe-Technologies/smart_open>`_ on every commit push & pull request.
574
+
575
+ Comments, bug reports
576
+ =====================
577
+
578
+ ``smart_open`` lives on `Github <https://github.com/RaRe-Technologies/smart_open>`_. You can file
579
+ issues or pull requests there. Suggestions, pull requests and improvements welcome!
580
+
581
+ ----------------
582
+
583
+ ``smart_open`` is open source software released under the `MIT license <https://github.com/piskvorky/smart_open/blob/master/LICENSE>`_.
584
+ Copyright (c) 2015-now `Radim Řehůřek <https://radimrehurek.com>`_.
585
+
586
+
.venv/lib/python3.11/site-packages/smart_open-7.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.45.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
.venv/lib/python3.11/site-packages/torchvision/__init__.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ from modulefinder import Module
4
+
5
+ import torch
6
+
7
+ # Don't re-order these, we need to load the _C extension (done when importing
8
+ # .extensions) before entering _meta_registrations.
9
+ from .extension import _HAS_OPS # usort:skip
10
+ from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils # usort:skip
11
+
12
+ try:
13
+ from .version import __version__ # noqa: F401
14
+ except ImportError:
15
+ pass
16
+
17
+
18
+ # Check if torchvision is being imported within the root folder
19
+ if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join(
20
+ os.path.realpath(os.getcwd()), "torchvision"
21
+ ):
22
+ message = (
23
+ "You are importing torchvision within its own root folder ({}). "
24
+ "This is not expected to work and may give errors. Please exit the "
25
+ "torchvision project source and relaunch your python interpreter."
26
+ )
27
+ warnings.warn(message.format(os.getcwd()))
28
+
29
# Currently selected backends; callers mutate these only through the
# set_*_backend functions below.
_image_backend = "PIL"

_video_backend = "pyav"


def set_image_backend(backend):
    """Select the package torchvision uses to load images.

    Args:
        backend (string): Either ``'PIL'`` or ``'accimage'``. The
            :mod:`accimage` backend, built on the Intel IPP library, is
            generally faster than PIL but supports fewer operations.

    Raises:
        ValueError: If *backend* is not one of the supported names.
    """
    global _image_backend
    if backend not in ("PIL", "accimage"):
        raise ValueError(f"Invalid backend '{backend}'. Options are 'PIL' and 'accimage'")
    _image_backend = backend


def get_image_backend():
    """Return the name of the package currently used to load images."""
    return _image_backend
54
+
55
+
56
def set_video_backend(backend):
    """
    Specifies the package used to decode videos.

    Args:
        backend (string): Name of the video backend. one of {'pyav', 'video_reader'}.
            The :mod:`pyav` package uses the 3rd party PyAv library. It is a Pythonic
            binding for the FFmpeg libraries.
            The :mod:`video_reader` package includes a native C++ implementation on
            top of FFMPEG libraries, and a python API of TorchScript custom operator.
            It generally decodes faster than :mod:`pyav`, but is perhaps less robust.

    Raises:
        ValueError: If *backend* is not one of 'pyav', 'video_reader', 'cuda'.
        RuntimeError: If the requested backend was not compiled into this build.

    .. note::
        Building with FFMPEG is disabled by default in the latest `main`. If you want to use the 'video_reader'
        backend, please compile torchvision from source.
    """
    global _video_backend
    # Guard clauses; error message formatting made consistent with
    # set_image_backend (f-string instead of dated %-formatting).
    if backend not in ["pyav", "video_reader", "cuda"]:
        raise ValueError(f"Invalid video backend '{backend}'. Options are 'pyav', 'video_reader' and 'cuda'")
    if backend == "video_reader" and not io._HAS_CPU_VIDEO_DECODER:
        # TODO: better messages
        raise RuntimeError(
            "video_reader video backend is not available. Please compile torchvision from source and try again"
        )
    if backend == "cuda" and not io._HAS_GPU_VIDEO_DECODER:
        # TODO: better messages
        raise RuntimeError("cuda video backend is not available.")
    _video_backend = backend
85
+
86
+
87
def get_video_backend():
    """Return the name of the video backend currently in use for decoding.

    Returns:
        str: One of ``'pyav'`` or ``'video_reader'``.
    """
    backend = _video_backend
    return backend
96
+
97
+
98
+ def _is_tracing():
99
+ return torch._C._get_tracing_state()
100
+
101
+
102
def disable_beta_transforms_warning():
    """Intentionally do nothing.

    Kept only so existing code that calls it keeps working; see
    https://github.com/pytorch/vision/issues/7896.
    """
.venv/lib/python3.11/site-packages/torchvision/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.53 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/__pycache__/_internally_replaced_utils.cpython-311.pyc ADDED
Binary file (2.44 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/__pycache__/_meta_registrations.cpython-311.pyc ADDED
Binary file (12.5 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (2.29 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/__pycache__/extension.cpython-311.pyc ADDED
Binary file (4.05 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/__pycache__/utils.cpython-311.pyc ADDED
Binary file (37.9 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/__pycache__/version.cpython-311.pyc ADDED
Binary file (460 Bytes). View file
 
.venv/lib/python3.11/site-packages/torchvision/_internally_replaced_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.machinery
2
+ import os
3
+
4
+ from torch.hub import _get_torch_home
5
+
6
+
7
# Default local cache root for torchvision datasets, under torch's hub home.
_HOME = os.path.join(_get_torch_home(), "datasets", "vision")
# Internal toggle; sharded dataset loading is disabled by default here.
_USE_SHARDED_DATASETS = False
9
+
10
+
11
+ def _download_file_from_remote_location(fpath: str, url: str) -> None:
12
+ pass
13
+
14
+
15
+ def _is_remote_location_available() -> bool:
16
+ return False
17
+
18
+
19
+ try:
20
+ from torch.hub import load_state_dict_from_url # noqa: 401
21
+ except ImportError:
22
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url # noqa: 401
23
+
24
+
25
def _get_extension_path(lib_name):
    """Return the filesystem path of the compiled extension *lib_name*.

    Searches torchvision's package directory for a file matching the
    platform's extension suffixes; raises ImportError when none is found.
    """
    lib_dir = os.path.dirname(__file__)
    if os.name == "nt":
        # Register the main torchvision library location on the default DLL path
        import ctypes

        kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
        with_load_library_flags = hasattr(kernel32, "AddDllDirectory")
        # Suppress the "missing DLL" error dialog while probing directories.
        prev_error_mode = kernel32.SetErrorMode(0x0001)

        if with_load_library_flags:
            kernel32.AddDllDirectory.restype = ctypes.c_void_p

        os.add_dll_directory(lib_dir)

        # Restore the process error mode changed above.
        kernel32.SetErrorMode(prev_error_mode)

    loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES)

    extfinder = importlib.machinery.FileFinder(lib_dir, loader_details)
    ext_specs = extfinder.find_spec(lib_name)
    if ext_specs is None:
        raise ImportError

    return ext_specs.origin
.venv/lib/python3.11/site-packages/torchvision/_meta_registrations.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import torch
4
+ import torch._custom_ops
5
+ import torch.library
6
+
7
+ # Ensure that torch.ops.torchvision is visible
8
+ import torchvision.extension # noqa: F401
9
+
10
+
11
@functools.lru_cache(None)
def get_meta_lib():
    # Cached so the dispatcher Library handle for torchvision's Meta-key
    # implementations is created at most once per process.
    return torch.library.Library("torchvision", "IMPL", "Meta")
14
+
15
+
16
def register_meta(op_name, overload_name="default"):
    """Decorator registering *fn* as the Meta-dispatch impl of a torchvision op."""

    def decorator(fn):
        # Only register when the compiled C++ ops are actually loaded.
        if torchvision.extension._has_ops():
            op = getattr(getattr(torch.ops.torchvision, op_name), overload_name)
            get_meta_lib().impl(op, fn)
        return fn

    return decorator
23
+
24
+
25
+ @register_meta("roi_align")
26
+ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
27
+ torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
28
+ torch._check(
29
+ input.dtype == rois.dtype,
30
+ lambda: (
31
+ "Expected tensor for input to have the same type as tensor for rois; "
32
+ f"but type {input.dtype} does not equal {rois.dtype}"
33
+ ),
34
+ )
35
+ num_rois = rois.size(0)
36
+ channels = input.size(1)
37
+ return input.new_empty((num_rois, channels, pooled_height, pooled_width))
38
+
39
+
40
+ @register_meta("_roi_align_backward")
41
+ def meta_roi_align_backward(
42
+ grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned
43
+ ):
44
+ torch._check(
45
+ grad.dtype == rois.dtype,
46
+ lambda: (
47
+ "Expected tensor for grad to have the same type as tensor for rois; "
48
+ f"but type {grad.dtype} does not equal {rois.dtype}"
49
+ ),
50
+ )
51
+ return grad.new_empty((batch_size, channels, height, width))
52
+
53
+
54
+ @register_meta("ps_roi_align")
55
+ def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
56
+ torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
57
+ torch._check(
58
+ input.dtype == rois.dtype,
59
+ lambda: (
60
+ "Expected tensor for input to have the same type as tensor for rois; "
61
+ f"but type {input.dtype} does not equal {rois.dtype}"
62
+ ),
63
+ )
64
+ channels = input.size(1)
65
+ torch._check(
66
+ channels % (pooled_height * pooled_width) == 0,
67
+ "input channels must be a multiple of pooling height * pooling width",
68
+ )
69
+
70
+ num_rois = rois.size(0)
71
+ out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
72
+ return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")
73
+
74
+
75
+ @register_meta("_ps_roi_align_backward")
76
+ def meta_ps_roi_align_backward(
77
+ grad,
78
+ rois,
79
+ channel_mapping,
80
+ spatial_scale,
81
+ pooled_height,
82
+ pooled_width,
83
+ sampling_ratio,
84
+ batch_size,
85
+ channels,
86
+ height,
87
+ width,
88
+ ):
89
+ torch._check(
90
+ grad.dtype == rois.dtype,
91
+ lambda: (
92
+ "Expected tensor for grad to have the same type as tensor for rois; "
93
+ f"but type {grad.dtype} does not equal {rois.dtype}"
94
+ ),
95
+ )
96
+ return grad.new_empty((batch_size, channels, height, width))
97
+
98
+
99
+ @register_meta("roi_pool")
100
+ def meta_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
101
+ torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
102
+ torch._check(
103
+ input.dtype == rois.dtype,
104
+ lambda: (
105
+ "Expected tensor for input to have the same type as tensor for rois; "
106
+ f"but type {input.dtype} does not equal {rois.dtype}"
107
+ ),
108
+ )
109
+ num_rois = rois.size(0)
110
+ channels = input.size(1)
111
+ out_size = (num_rois, channels, pooled_height, pooled_width)
112
+ return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
113
+
114
+
115
+ @register_meta("_roi_pool_backward")
116
+ def meta_roi_pool_backward(
117
+ grad, rois, argmax, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
118
+ ):
119
+ torch._check(
120
+ grad.dtype == rois.dtype,
121
+ lambda: (
122
+ "Expected tensor for grad to have the same type as tensor for rois; "
123
+ f"but type {grad.dtype} does not equal {rois.dtype}"
124
+ ),
125
+ )
126
+ return grad.new_empty((batch_size, channels, height, width))
127
+
128
+
129
+ @register_meta("ps_roi_pool")
130
+ def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width):
131
+ torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
132
+ torch._check(
133
+ input.dtype == rois.dtype,
134
+ lambda: (
135
+ "Expected tensor for input to have the same type as tensor for rois; "
136
+ f"but type {input.dtype} does not equal {rois.dtype}"
137
+ ),
138
+ )
139
+ channels = input.size(1)
140
+ torch._check(
141
+ channels % (pooled_height * pooled_width) == 0,
142
+ "input channels must be a multiple of pooling height * pooling width",
143
+ )
144
+ num_rois = rois.size(0)
145
+ out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
146
+ return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32)
147
+
148
+
149
+ @register_meta("_ps_roi_pool_backward")
150
+ def meta_ps_roi_pool_backward(
151
+ grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width
152
+ ):
153
+ torch._check(
154
+ grad.dtype == rois.dtype,
155
+ lambda: (
156
+ "Expected tensor for grad to have the same type as tensor for rois; "
157
+ f"but type {grad.dtype} does not equal {rois.dtype}"
158
+ ),
159
+ )
160
+ return grad.new_empty((batch_size, channels, height, width))
161
+
162
+
163
@torch.library.register_fake("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
    """Fake (meta) kernel for torchvision::nms.

    Validates argument shapes only; the number of kept boxes is data-dependent,
    so the output length is modeled as an unbacked symbolic int.
    """
    torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
    torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}")
    torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}")
    torch._check(
        dets.size(0) == scores.size(0),
        lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}",
    )
    # The true output length is unknown at trace time; allocate with an
    # unbacked SymInt so downstream shape reasoning stays sound.
    ctx = torch._custom_ops.get_ctx()
    num_to_keep = ctx.create_unbacked_symint()
    return dets.new_empty(num_to_keep, dtype=torch.long)
175
+
176
+
177
+ @register_meta("deform_conv2d")
178
def meta_deform_conv2d(
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dil_h,
    dil_w,
    n_weight_grps,
    n_offset_grps,
    use_mask,
):
    """Meta kernel for deform_conv2d.

    The spatial output size is carried by the offset tensor's trailing dims.
    """
    out_h, out_w = offset.shape[-2:]
    return input.new_empty((input.shape[0], weight.shape[0], out_h, out_w))
199
+
200
+
201
+ @register_meta("_deform_conv2d_backward")
202
def meta_deform_conv2d_backward(
    grad,
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation_h,
    dilation_w,
    groups,
    offset_groups,
    use_mask,
):
    """Meta kernel for deform_conv2d backward: one gradient per input, same shapes."""
    return (
        input.new_empty(input.shape),
        weight.new_empty(weight.shape),
        offset.new_empty(offset.shape),
        mask.new_empty(mask.shape),
        bias.new_empty(bias.shape),
    )
.venv/lib/python3.11/site-packages/torchvision/_utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import enum
2
+ from typing import Sequence, Type, TypeVar
3
+
4
# Enum-bound type variable used by StrEnumMeta.from_str's return annotation.
T = TypeVar("T", bound=enum.Enum)
5
+
6
+
7
class StrEnumMeta(enum.EnumMeta):
    """Enum metaclass adding name-based lookup via :meth:`from_str`."""

    auto = enum.auto

    def from_str(self: Type[T], member: str) -> T:  # type: ignore[misc]
        """Return the enum member named *member*; raise ValueError if unknown."""
        try:
            return self[member]
        except KeyError:
            # TODO: use `add_suggestion` from torchvision.prototype.utils._internal to improve the error message as
            # soon as it is migrated.
            raise ValueError(f"Unknown value '{member}' for {self.__name__}.") from None
17
+
18
+
19
class StrEnum(enum.Enum, metaclass=StrEnumMeta):
    # Base class for enums that support lookup by member name via `from_str`.
    pass
21
+
22
+
23
def sequence_to_str(seq: Sequence, separate_last: str = "") -> str:
    """Render *seq* as a quoted, comma-separated, human-readable string.

    ``separate_last`` (e.g. ``"and "``) is inserted before the final item;
    with exactly two items and a separator, the comma is omitted.
    """
    if not seq:
        return ""
    if len(seq) == 1:
        return f"'{seq[0]}'"

    quoted = [f"'{item}'" for item in seq]
    comma = "" if separate_last and len(seq) == 2 else ","
    return ", ".join(quoted[:-1]) + f"{comma} {separate_last}{quoted[-1]}"
.venv/lib/python3.11/site-packages/torchvision/extension.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import torch
5
+
6
+ from ._internally_replaced_utils import _get_extension_path
7
+
8
+
9
# Set to True below once the compiled C++ extension loads successfully.
_HAS_OPS = False


def _has_ops():
    # Fallback: compiled ops unavailable. Shadowed by the redefinition inside
    # the `try` block below when the extension does load.
    return False


try:
    # On Windows Python-3.8.x has `os.add_dll_directory` call,
    # which is called to configure dll search path.
    # To find cuda related dlls we need to make sure the
    # conda environment/bin path is configured Please take a look:
    # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
    # Please note: if some path can't be added using add_dll_directory we simply ignore this path
    if os.name == "nt" and sys.version_info < (3, 9):
        env_path = os.environ["PATH"]
        path_arr = env_path.split(";")
        for path in path_arr:
            if os.path.exists(path):
                try:
                    os.add_dll_directory(path)  # type: ignore[attr-defined]
                except Exception:
                    pass

    lib_path = _get_extension_path("_C")
    torch.ops.load_library(lib_path)
    _HAS_OPS = True

    def _has_ops():  # noqa: F811
        # Extension loaded: report the compiled ops as available.
        return True

except (ImportError, OSError):
    # Missing or incompatible extension: keep the False fallbacks above.
    pass
42
+
43
+
44
def _assert_has_ops():
    """Raise a descriptive RuntimeError when the compiled C++ ops are unavailable."""
    if _has_ops():
        return
    raise RuntimeError(
        "Couldn't load custom C++ ops. This can happen if your PyTorch and "
        "torchvision versions are incompatible, or if you had errors while compiling "
        "torchvision from source. For further information on the compatible versions, check "
        "https://github.com/pytorch/vision#installation for the compatibility matrix. "
        "Please check your PyTorch version with torch.__version__ and your torchvision "
        "version with torchvision.__version__ and verify if they are compatible, and if not "
        "please reinstall torchvision so that it matches your PyTorch install."
    )
55
+
56
+
57
def _check_cuda_version():
    """
    Make sure that CUDA versions match between the pytorch install and torchvision install
    """
    if not _HAS_OPS:
        # No compiled extension -> nothing to compare.
        return -1
    from torch.version import cuda as torch_version_cuda

    # The extension reports its build-time CUDA version as an int (or -1 for CPU-only).
    _version = torch.ops.torchvision._cuda_version()
    if _version != -1 and torch_version_cuda is not None:
        tv_version = str(_version)
        if int(tv_version) < 10000:
            # 4-digit encoding: digit 0 is the major, digit 2 the minor (e.g. "9020" -> 9.2).
            tv_major = int(tv_version[0])
            tv_minor = int(tv_version[2])
        else:
            # 5-digit encoding: digits 0-1 are the major, digit 3 the minor (e.g. "11080" -> 11.8).
            tv_major = int(tv_version[0:2])
            tv_minor = int(tv_version[3])
        t_version = torch_version_cuda.split(".")
        t_major = int(t_version[0])
        t_minor = int(t_version[1])
        # Only a major-version mismatch is treated as fatal.
        if t_major != tv_major:
            raise RuntimeError(
                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
                f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
                f"CUDA Version={tv_major}.{tv_minor}. "
                "Please reinstall the torchvision that matches your PyTorch install."
            )
    return _version
85
+
86
+
87
def _load_library(lib_name):
    """Locate the named torchvision extension library and load it into torch.ops."""
    torch.ops.load_library(_get_extension_path(lib_name))
90
+
91
+
92
+ _check_cuda_version()
.venv/lib/python3.11/site-packages/torchvision/models/alexnet.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..transforms._presets import ImageClassification
8
+ from ..utils import _log_api_usage_once
9
+ from ._api import register_model, Weights, WeightsEnum
10
+ from ._meta import _IMAGENET_CATEGORIES
11
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
12
+
13
+
14
+ __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
15
+
16
+
17
class AlexNet(nn.Module):
    # AlexNet ("one weird trick" variant): five conv layers with ReLU and three
    # max-pools, an adaptive 6x6 average pool, then a three-layer classifier MLP.
    def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        # Adaptive pooling fixes the classifier input size (256*6*6) regardless
        # of the spatial input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
53
+
54
+
55
class AlexNet_Weights(WeightsEnum):
    """Pretrained ImageNet-1K weights for :func:`alexnet`."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 61100840,
            "min_size": (63, 63),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 56.522,
                    "acc@5": 79.066,
                }
            },
            "_ops": 0.714,
            "_file_size": 233.087,
            "_docs": """
                These weights reproduce closely the results of the paper using a simplified training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
78
+
79
+
80
@register_model()
@handle_legacy_interface(weights=("pretrained", AlexNet_Weights.IMAGENET1K_V1))
def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True, **kwargs: Any) -> AlexNet:
    """AlexNet model architecture from `One weird trick for parallelizing convolutional neural networks <https://arxiv.org/abs/1404.5997>`__.

    .. note::
        AlexNet was originally introduced in the `ImageNet Classification with
        Deep Convolutional Neural Networks
        <https://papers.nips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html>`__
        paper; this implementation follows the "One weird trick" paper above.

    Args:
        weights (:class:`~torchvision.models.AlexNet_Weights`, optional): The
            pretrained weights to use. ``None`` (the default) means no
            pre-trained weights.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters forwarded to the ``torchvision.models.alexnet.AlexNet``
            base class.

    .. autoclass:: torchvision.models.AlexNet_Weights
        :members:
    """
    weights = AlexNet_Weights.verify(weights)

    if weights is not None:
        # A pretrained checkpoint fixes the classifier size to its class count.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = AlexNet(**kwargs)
    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model
.venv/lib/python3.11/site-packages/torchvision/models/convnext.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Callable, List, Optional, Sequence
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torch.nn import functional as F
7
+
8
+ from ..ops.misc import Conv2dNormActivation, Permute
9
+ from ..ops.stochastic_depth import StochasticDepth
10
+ from ..transforms._presets import ImageClassification
11
+ from ..utils import _log_api_usage_once
12
+ from ._api import register_model, Weights, WeightsEnum
13
+ from ._meta import _IMAGENET_CATEGORIES
14
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
15
+
16
+
17
+ __all__ = [
18
+ "ConvNeXt",
19
+ "ConvNeXt_Tiny_Weights",
20
+ "ConvNeXt_Small_Weights",
21
+ "ConvNeXt_Base_Weights",
22
+ "ConvNeXt_Large_Weights",
23
+ "convnext_tiny",
24
+ "convnext_small",
25
+ "convnext_base",
26
+ "convnext_large",
27
+ ]
28
+
29
+
30
+ class LayerNorm2d(nn.LayerNorm):
31
+ def forward(self, x: Tensor) -> Tensor:
32
+ x = x.permute(0, 2, 3, 1)
33
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
34
+ x = x.permute(0, 3, 1, 2)
35
+ return x
36
+
37
+
38
class CNBlock(nn.Module):
    # ConvNeXt block: depthwise 7x7 conv -> LayerNorm (channels-last) ->
    # pointwise MLP (4x expansion, GELU) -> layer scale -> stochastic depth,
    # all wrapped in a residual connection.
    def __init__(
        self,
        dim,
        layer_scale: float,
        stochastic_depth_prob: float,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=True),
            Permute([0, 2, 3, 1]),
            norm_layer(dim),
            nn.Linear(in_features=dim, out_features=4 * dim, bias=True),
            nn.GELU(),
            nn.Linear(in_features=4 * dim, out_features=dim, bias=True),
            Permute([0, 3, 1, 2]),
        )
        # Learnable per-channel scaling of the residual branch.
        self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

    def forward(self, input: Tensor) -> Tensor:
        result = self.layer_scale * self.block(input)
        result = self.stochastic_depth(result)
        result += input
        return result
+
68
+
69
class CNBlockConfig:
    """Per-stage ConvNeXt configuration (see Section 3 of the ConvNeXt paper)."""

    def __init__(
        self,
        input_channels: int,
        out_channels: Optional[int],
        num_layers: int,
    ) -> None:
        self.input_channels = input_channels
        self.out_channels = out_channels  # None means no downsampling after the stage
        self.num_layers = num_layers

    def __repr__(self) -> str:
        template = (
            self.__class__.__name__
            + "(input_channels={input_channels}, out_channels={out_channels}, num_layers={num_layers})"
        )
        return template.format(**self.__dict__)
+
89
+
90
class ConvNeXt(nn.Module):
    # ConvNeXt backbone: a 4x4/stride-4 stem, then per-stage CNBlock stacks with
    # 2x2/stride-2 downsampling between stages, global average pooling, and a
    # LayerNorm + Linear classifier head.
    def __init__(
        self,
        block_setting: List[CNBlockConfig],
        stochastic_depth_prob: float = 0.0,
        layer_scale: float = 1e-6,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if not block_setting:
            raise ValueError("The block_setting should not be empty")
        elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])):
            raise TypeError("The block_setting should be List[CNBlockConfig]")

        if block is None:
            block = CNBlock

        if norm_layer is None:
            norm_layer = partial(LayerNorm2d, eps=1e-6)

        layers: List[nn.Module] = []

        # Stem
        firstconv_output_channels = block_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=4,
                stride=4,
                padding=0,
                norm_layer=norm_layer,
                activation_layer=None,
                bias=True,
            )
        )

        total_stage_blocks = sum(cnf.num_layers for cnf in block_setting)
        stage_block_id = 0
        for cnf in block_setting:
            # Bottlenecks
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0)
                stage.append(block(cnf.input_channels, layer_scale, sd_prob))
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            if cnf.out_channels is not None:
                # Downsampling
                layers.append(
                    nn.Sequential(
                        norm_layer(cnf.input_channels),
                        nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2),
                    )
                )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # The head width is the last stage's output (or input, when the last
        # stage has no downsampling step).
        lastblock = block_setting[-1]
        lastconv_output_channels = (
            lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels
        )
        self.classifier = nn.Sequential(
            norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes)
        )

        # Truncated-normal init for all conv/linear weights, zeros for biases.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
+
178
+
179
def _convnext(
    block_setting: List[CNBlockConfig],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> ConvNeXt:
    """Build a ConvNeXt from ``block_setting``, optionally loading pretrained weights."""
    if weights is not None:
        # A pretrained checkpoint dictates the number of output classes.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)
    if weights is not None:
        state_dict = weights.get_state_dict(progress=progress, check_hash=True)
        model.load_state_dict(state_dict)
    return model
195
+
196
+
197
# Metadata shared by every ConvNeXt weight entry below.
_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
    "_docs": """
        These weights improve upon the results of the original paper by using a modified version of TorchVision's
        `new training recipe
        <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
    """,
}
+ }
207
+
208
+
209
class ConvNeXt_Tiny_Weights(WeightsEnum):
    """Pretrained ImageNet-1K weights for :func:`convnext_tiny`."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=236),
        meta={
            **_COMMON_META,
            "num_params": 28589128,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.520,
                    "acc@5": 96.146,
                }
            },
            "_ops": 4.456,
            "_file_size": 109.119,
        },
    )
    DEFAULT = IMAGENET1K_V1
227
+
228
+
229
class ConvNeXt_Small_Weights(WeightsEnum):
    """Pretrained ImageNet-1K weights for :func:`convnext_small`."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=230),
        meta={
            **_COMMON_META,
            "num_params": 50223688,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.616,
                    "acc@5": 96.650,
                }
            },
            "_ops": 8.684,
            "_file_size": 191.703,
        },
    )
    DEFAULT = IMAGENET1K_V1
247
+
248
+
249
class ConvNeXt_Base_Weights(WeightsEnum):
    """Pretrained ImageNet-1K weights for :func:`convnext_base`."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 88591464,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.062,
                    "acc@5": 96.870,
                }
            },
            "_ops": 15.355,
            "_file_size": 338.064,
        },
    )
    DEFAULT = IMAGENET1K_V1
267
+
268
+
269
class ConvNeXt_Large_Weights(WeightsEnum):
    """Pretrained ImageNet-1K weights for :func:`convnext_large`."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 197767336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.414,
                    "acc@5": 96.976,
                }
            },
            "_ops": 34.361,
            "_file_size": 754.537,
        },
    )
    DEFAULT = IMAGENET1K_V1
287
+
288
+
289
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.IMAGENET1K_V1))
def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Tiny model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Tiny_Weights`, optional): The pretrained
            weights to use; ``None`` (the default) means no pre-trained weights.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters forwarded to the ``torchvision.models.convnext.ConvNeXt`` base class.

    .. autoclass:: torchvision.models.ConvNeXt_Tiny_Weights
        :members:
    """
    weights = ConvNeXt_Tiny_Weights.verify(weights)

    # Tiny config: widths 96->768 doubling per stage, depths (3, 3, 9, 3).
    stage_configs = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 9),
        CNBlockConfig(768, None, 3),
    ]
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.1)
    return _convnext(stage_configs, sd_prob, weights, progress, **kwargs)
318
+
319
+
320
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Small_Weights.IMAGENET1K_V1))
def convnext_small(
    *, weights: Optional[ConvNeXt_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Small model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Small_Weights`, optional): The pretrained
            weights to use; ``None`` (the default) means no pre-trained weights.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters forwarded to the ``torchvision.models.convnext.ConvNeXt`` base class.

    .. autoclass:: torchvision.models.ConvNeXt_Small_Weights
        :members:
    """
    weights = ConvNeXt_Small_Weights.verify(weights)

    # Small config: Tiny widths with a deeper (27-block) third stage.
    stage_configs = [
        CNBlockConfig(96, 192, 3),
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 27),
        CNBlockConfig(768, None, 3),
    ]
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.4)
    return _convnext(stage_configs, sd_prob, weights, progress, **kwargs)
351
+
352
+
353
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Base_Weights.IMAGENET1K_V1))
def convnext_base(*, weights: Optional[ConvNeXt_Base_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt:
    """ConvNeXt Base model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Base_Weights`, optional): The pretrained
            weights to use; ``None`` (the default) means no pre-trained weights.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters forwarded to the ``torchvision.models.convnext.ConvNeXt`` base class.

    .. autoclass:: torchvision.models.ConvNeXt_Base_Weights
        :members:
    """
    weights = ConvNeXt_Base_Weights.verify(weights)

    # Base config: widths 128->1024, depths (3, 3, 27, 3).
    stage_configs = [
        CNBlockConfig(128, 256, 3),
        CNBlockConfig(256, 512, 3),
        CNBlockConfig(512, 1024, 27),
        CNBlockConfig(1024, None, 3),
    ]
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(stage_configs, sd_prob, weights, progress, **kwargs)
382
+
383
+
384
@register_model()
@handle_legacy_interface(weights=("pretrained", ConvNeXt_Large_Weights.IMAGENET1K_V1))
def convnext_large(
    *, weights: Optional[ConvNeXt_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> ConvNeXt:
    """ConvNeXt Large model architecture from the
    `A ConvNet for the 2020s <https://arxiv.org/abs/2201.03545>`_ paper.

    Args:
        weights (:class:`~torchvision.models.convnext.ConvNeXt_Large_Weights`, optional): The pretrained
            weights to use; ``None`` (the default) means no pre-trained weights.
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters forwarded to the ``torchvision.models.convnext.ConvNeXt`` base class.

    .. autoclass:: torchvision.models.ConvNeXt_Large_Weights
        :members:
    """
    weights = ConvNeXt_Large_Weights.verify(weights)

    # Large config: widths 192->1536, depths (3, 3, 27, 3).
    stage_configs = [
        CNBlockConfig(192, 384, 3),
        CNBlockConfig(384, 768, 3),
        CNBlockConfig(768, 1536, 27),
        CNBlockConfig(1536, None, 3),
    ]
    sd_prob = kwargs.pop("stochastic_depth_prob", 0.5)
    return _convnext(stage_configs, sd_prob, weights, progress, **kwargs)
.venv/lib/python3.11/site-packages/torchvision/models/densenet.py ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from collections import OrderedDict
3
+ from functools import partial
4
+ from typing import Any, List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint as cp
10
+ from torch import Tensor
11
+
12
+ from ..transforms._presets import ImageClassification
13
+ from ..utils import _log_api_usage_once
14
+ from ._api import register_model, Weights, WeightsEnum
15
+ from ._meta import _IMAGENET_CATEGORIES
16
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
17
+
18
+ __all__ = [
19
+ "DenseNet",
20
+ "DenseNet121_Weights",
21
+ "DenseNet161_Weights",
22
+ "DenseNet169_Weights",
23
+ "DenseNet201_Weights",
24
+ "densenet121",
25
+ "densenet161",
26
+ "densenet169",
27
+ "densenet201",
28
+ ]
29
+
30
+
31
+ class _DenseLayer(nn.Module):
32
+ def __init__(
33
+ self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
34
+ ) -> None:
35
+ super().__init__()
36
+ self.norm1 = nn.BatchNorm2d(num_input_features)
37
+ self.relu1 = nn.ReLU(inplace=True)
38
+ self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
39
+
40
+ self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
41
+ self.relu2 = nn.ReLU(inplace=True)
42
+ self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
43
+
44
+ self.drop_rate = float(drop_rate)
45
+ self.memory_efficient = memory_efficient
46
+
47
+ def bn_function(self, inputs: List[Tensor]) -> Tensor:
48
+ concated_features = torch.cat(inputs, 1)
49
+ bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484
50
+ return bottleneck_output
51
+
52
+ # todo: rewrite when torchscript supports any
53
+ def any_requires_grad(self, input: List[Tensor]) -> bool:
54
+ for tensor in input:
55
+ if tensor.requires_grad:
56
+ return True
57
+ return False
58
+
59
+ @torch.jit.unused # noqa: T484
60
+ def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
61
+ def closure(*inputs):
62
+ return self.bn_function(inputs)
63
+
64
+ return cp.checkpoint(closure, *input, use_reentrant=False)
65
+
66
+ @torch.jit._overload_method # noqa: F811
67
+ def forward(self, input: List[Tensor]) -> Tensor: # noqa: F811
68
+ pass
69
+
70
+ @torch.jit._overload_method # noqa: F811
71
+ def forward(self, input: Tensor) -> Tensor: # noqa: F811
72
+ pass
73
+
74
+ # torchscript does not yet support *args, so we overload method
75
+ # allowing it to take either a List[Tensor] or single Tensor
76
+ def forward(self, input: Tensor) -> Tensor: # noqa: F811
77
+ if isinstance(input, Tensor):
78
+ prev_features = [input]
79
+ else:
80
+ prev_features = input
81
+
82
+ if self.memory_efficient and self.any_requires_grad(prev_features):
83
+ if torch.jit.is_scripting():
84
+ raise Exception("Memory Efficient not supported in JIT")
85
+
86
+ bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
87
+ else:
88
+ bottleneck_output = self.bn_function(prev_features)
89
+
90
+ new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
91
+ if self.drop_rate > 0:
92
+ new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
93
+ return new_features
94
+
95
+
96
+ class _DenseBlock(nn.ModuleDict):
97
+ _version = 2
98
+
99
+ def __init__(
100
+ self,
101
+ num_layers: int,
102
+ num_input_features: int,
103
+ bn_size: int,
104
+ growth_rate: int,
105
+ drop_rate: float,
106
+ memory_efficient: bool = False,
107
+ ) -> None:
108
+ super().__init__()
109
+ for i in range(num_layers):
110
+ layer = _DenseLayer(
111
+ num_input_features + i * growth_rate,
112
+ growth_rate=growth_rate,
113
+ bn_size=bn_size,
114
+ drop_rate=drop_rate,
115
+ memory_efficient=memory_efficient,
116
+ )
117
+ self.add_module("denselayer%d" % (i + 1), layer)
118
+
119
+ def forward(self, init_features: Tensor) -> Tensor:
120
+ features = [init_features]
121
+ for name, layer in self.items():
122
+ new_features = layer(features)
123
+ features.append(new_features)
124
+ return torch.cat(features, 1)
125
+
126
+
127
+ class _Transition(nn.Sequential):
128
+ def __init__(self, num_input_features: int, num_output_features: int) -> None:
129
+ super().__init__()
130
+ self.norm = nn.BatchNorm2d(num_input_features)
131
+ self.relu = nn.ReLU(inplace=True)
132
+ self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)
133
+ self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
134
+
135
+
136
+ class DenseNet(nn.Module):
137
+ r"""Densenet-BC model class, based on
138
+ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_.
139
+
140
+ Args:
141
+ growth_rate (int) - how many filters to add each layer (`k` in paper)
142
+ block_config (list of 4 ints) - how many layers in each pooling block
143
+ num_init_features (int) - the number of filters to learn in the first convolution layer
144
+ bn_size (int) - multiplicative factor for number of bottle neck layers
145
+ (i.e. bn_size * k features in the bottleneck layer)
146
+ drop_rate (float) - dropout rate after each dense layer
147
+ num_classes (int) - number of classification classes
148
+ memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
149
+ but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ growth_rate: int = 32,
155
+ block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
156
+ num_init_features: int = 64,
157
+ bn_size: int = 4,
158
+ drop_rate: float = 0,
159
+ num_classes: int = 1000,
160
+ memory_efficient: bool = False,
161
+ ) -> None:
162
+
163
+ super().__init__()
164
+ _log_api_usage_once(self)
165
+
166
+ # First convolution
167
+ self.features = nn.Sequential(
168
+ OrderedDict(
169
+ [
170
+ ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
171
+ ("norm0", nn.BatchNorm2d(num_init_features)),
172
+ ("relu0", nn.ReLU(inplace=True)),
173
+ ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
174
+ ]
175
+ )
176
+ )
177
+
178
+ # Each denseblock
179
+ num_features = num_init_features
180
+ for i, num_layers in enumerate(block_config):
181
+ block = _DenseBlock(
182
+ num_layers=num_layers,
183
+ num_input_features=num_features,
184
+ bn_size=bn_size,
185
+ growth_rate=growth_rate,
186
+ drop_rate=drop_rate,
187
+ memory_efficient=memory_efficient,
188
+ )
189
+ self.features.add_module("denseblock%d" % (i + 1), block)
190
+ num_features = num_features + num_layers * growth_rate
191
+ if i != len(block_config) - 1:
192
+ trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
193
+ self.features.add_module("transition%d" % (i + 1), trans)
194
+ num_features = num_features // 2
195
+
196
+ # Final batch norm
197
+ self.features.add_module("norm5", nn.BatchNorm2d(num_features))
198
+
199
+ # Linear layer
200
+ self.classifier = nn.Linear(num_features, num_classes)
201
+
202
+ # Official init from torch repo.
203
+ for m in self.modules():
204
+ if isinstance(m, nn.Conv2d):
205
+ nn.init.kaiming_normal_(m.weight)
206
+ elif isinstance(m, nn.BatchNorm2d):
207
+ nn.init.constant_(m.weight, 1)
208
+ nn.init.constant_(m.bias, 0)
209
+ elif isinstance(m, nn.Linear):
210
+ nn.init.constant_(m.bias, 0)
211
+
212
+ def forward(self, x: Tensor) -> Tensor:
213
+ features = self.features(x)
214
+ out = F.relu(features, inplace=True)
215
+ out = F.adaptive_avg_pool2d(out, (1, 1))
216
+ out = torch.flatten(out, 1)
217
+ out = self.classifier(out)
218
+ return out
219
+
220
+
221
+ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) -> None:
222
+ # '.'s are no longer allowed in module names, but previous _DenseLayer
223
+ # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
224
+ # They are also in the checkpoints in model_urls. This pattern is used
225
+ # to find such keys.
226
+ pattern = re.compile(
227
+ r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
228
+ )
229
+
230
+ state_dict = weights.get_state_dict(progress=progress, check_hash=True)
231
+ for key in list(state_dict.keys()):
232
+ res = pattern.match(key)
233
+ if res:
234
+ new_key = res.group(1) + res.group(2)
235
+ state_dict[new_key] = state_dict[key]
236
+ del state_dict[key]
237
+ model.load_state_dict(state_dict)
238
+
239
+
240
+ def _densenet(
241
+ growth_rate: int,
242
+ block_config: Tuple[int, int, int, int],
243
+ num_init_features: int,
244
+ weights: Optional[WeightsEnum],
245
+ progress: bool,
246
+ **kwargs: Any,
247
+ ) -> DenseNet:
248
+ if weights is not None:
249
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
250
+
251
+ model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
252
+
253
+ if weights is not None:
254
+ _load_state_dict(model=model, weights=weights, progress=progress)
255
+
256
+ return model
257
+
258
+
259
+ _COMMON_META = {
260
+ "min_size": (29, 29),
261
+ "categories": _IMAGENET_CATEGORIES,
262
+ "recipe": "https://github.com/pytorch/vision/pull/116",
263
+ "_docs": """These weights are ported from LuaTorch.""",
264
+ }
265
+
266
+
267
+ class DenseNet121_Weights(WeightsEnum):
268
+ IMAGENET1K_V1 = Weights(
269
+ url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
270
+ transforms=partial(ImageClassification, crop_size=224),
271
+ meta={
272
+ **_COMMON_META,
273
+ "num_params": 7978856,
274
+ "_metrics": {
275
+ "ImageNet-1K": {
276
+ "acc@1": 74.434,
277
+ "acc@5": 91.972,
278
+ }
279
+ },
280
+ "_ops": 2.834,
281
+ "_file_size": 30.845,
282
+ },
283
+ )
284
+ DEFAULT = IMAGENET1K_V1
285
+
286
+
287
+ class DenseNet161_Weights(WeightsEnum):
288
+ IMAGENET1K_V1 = Weights(
289
+ url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
290
+ transforms=partial(ImageClassification, crop_size=224),
291
+ meta={
292
+ **_COMMON_META,
293
+ "num_params": 28681000,
294
+ "_metrics": {
295
+ "ImageNet-1K": {
296
+ "acc@1": 77.138,
297
+ "acc@5": 93.560,
298
+ }
299
+ },
300
+ "_ops": 7.728,
301
+ "_file_size": 110.369,
302
+ },
303
+ )
304
+ DEFAULT = IMAGENET1K_V1
305
+
306
+
307
+ class DenseNet169_Weights(WeightsEnum):
308
+ IMAGENET1K_V1 = Weights(
309
+ url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
310
+ transforms=partial(ImageClassification, crop_size=224),
311
+ meta={
312
+ **_COMMON_META,
313
+ "num_params": 14149480,
314
+ "_metrics": {
315
+ "ImageNet-1K": {
316
+ "acc@1": 75.600,
317
+ "acc@5": 92.806,
318
+ }
319
+ },
320
+ "_ops": 3.36,
321
+ "_file_size": 54.708,
322
+ },
323
+ )
324
+ DEFAULT = IMAGENET1K_V1
325
+
326
+
327
+ class DenseNet201_Weights(WeightsEnum):
328
+ IMAGENET1K_V1 = Weights(
329
+ url="https://download.pytorch.org/models/densenet201-c1103571.pth",
330
+ transforms=partial(ImageClassification, crop_size=224),
331
+ meta={
332
+ **_COMMON_META,
333
+ "num_params": 20013928,
334
+ "_metrics": {
335
+ "ImageNet-1K": {
336
+ "acc@1": 76.896,
337
+ "acc@5": 93.370,
338
+ }
339
+ },
340
+ "_ops": 4.291,
341
+ "_file_size": 77.373,
342
+ },
343
+ )
344
+ DEFAULT = IMAGENET1K_V1
345
+
346
+
347
+ @register_model()
348
+ @handle_legacy_interface(weights=("pretrained", DenseNet121_Weights.IMAGENET1K_V1))
349
+ def densenet121(*, weights: Optional[DenseNet121_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
350
+ r"""Densenet-121 model from
351
+ `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
352
+
353
+ Args:
354
+ weights (:class:`~torchvision.models.DenseNet121_Weights`, optional): The
355
+ pretrained weights to use. See
356
+ :class:`~torchvision.models.DenseNet121_Weights` below for
357
+ more details, and possible values. By default, no pre-trained
358
+ weights are used.
359
+ progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
360
+ **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
361
+ base class. Please refer to the `source code
362
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
363
+ for more details about this class.
364
+
365
+ .. autoclass:: torchvision.models.DenseNet121_Weights
366
+ :members:
367
+ """
368
+ weights = DenseNet121_Weights.verify(weights)
369
+
370
+ return _densenet(32, (6, 12, 24, 16), 64, weights, progress, **kwargs)
371
+
372
+
373
+ @register_model()
374
+ @handle_legacy_interface(weights=("pretrained", DenseNet161_Weights.IMAGENET1K_V1))
375
+ def densenet161(*, weights: Optional[DenseNet161_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
376
+ r"""Densenet-161 model from
377
+ `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
378
+
379
+ Args:
380
+ weights (:class:`~torchvision.models.DenseNet161_Weights`, optional): The
381
+ pretrained weights to use. See
382
+ :class:`~torchvision.models.DenseNet161_Weights` below for
383
+ more details, and possible values. By default, no pre-trained
384
+ weights are used.
385
+ progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
386
+ **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
387
+ base class. Please refer to the `source code
388
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
389
+ for more details about this class.
390
+
391
+ .. autoclass:: torchvision.models.DenseNet161_Weights
392
+ :members:
393
+ """
394
+ weights = DenseNet161_Weights.verify(weights)
395
+
396
+ return _densenet(48, (6, 12, 36, 24), 96, weights, progress, **kwargs)
397
+
398
+
399
+ @register_model()
400
+ @handle_legacy_interface(weights=("pretrained", DenseNet169_Weights.IMAGENET1K_V1))
401
+ def densenet169(*, weights: Optional[DenseNet169_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
402
+ r"""Densenet-169 model from
403
+ `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
404
+
405
+ Args:
406
+ weights (:class:`~torchvision.models.DenseNet169_Weights`, optional): The
407
+ pretrained weights to use. See
408
+ :class:`~torchvision.models.DenseNet169_Weights` below for
409
+ more details, and possible values. By default, no pre-trained
410
+ weights are used.
411
+ progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
412
+ **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
413
+ base class. Please refer to the `source code
414
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
415
+ for more details about this class.
416
+
417
+ .. autoclass:: torchvision.models.DenseNet169_Weights
418
+ :members:
419
+ """
420
+ weights = DenseNet169_Weights.verify(weights)
421
+
422
+ return _densenet(32, (6, 12, 32, 32), 64, weights, progress, **kwargs)
423
+
424
+
425
+ @register_model()
426
+ @handle_legacy_interface(weights=("pretrained", DenseNet201_Weights.IMAGENET1K_V1))
427
+ def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool = True, **kwargs: Any) -> DenseNet:
428
+ r"""Densenet-201 model from
429
+ `Densely Connected Convolutional Networks <https://arxiv.org/abs/1608.06993>`_.
430
+
431
+ Args:
432
+ weights (:class:`~torchvision.models.DenseNet201_Weights`, optional): The
433
+ pretrained weights to use. See
434
+ :class:`~torchvision.models.DenseNet201_Weights` below for
435
+ more details, and possible values. By default, no pre-trained
436
+ weights are used.
437
+ progress (bool, optional): If True, displays a progress bar of the download to stderr. Default is True.
438
+ **kwargs: parameters passed to the ``torchvision.models.densenet.DenseNet``
439
+ base class. Please refer to the `source code
440
+ <https://github.com/pytorch/vision/blob/main/torchvision/models/densenet.py>`_
441
+ for more details about this class.
442
+
443
+ .. autoclass:: torchvision.models.DenseNet201_Weights
444
+ :members:
445
+ """
446
+ weights = DenseNet201_Weights.verify(weights)
447
+
448
+ return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
.venv/lib/python3.11/site-packages/torchvision/models/efficientnet.py ADDED
@@ -0,0 +1,1131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from dataclasses import dataclass
4
+ from functools import partial
5
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn, Tensor
9
+ from torchvision.ops import StochasticDepth
10
+
11
+ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation
12
+ from ..transforms._presets import ImageClassification, InterpolationMode
13
+ from ..utils import _log_api_usage_once
14
+ from ._api import register_model, Weights, WeightsEnum
15
+ from ._meta import _IMAGENET_CATEGORIES
16
+ from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
17
+
18
+
19
+ __all__ = [
20
+ "EfficientNet",
21
+ "EfficientNet_B0_Weights",
22
+ "EfficientNet_B1_Weights",
23
+ "EfficientNet_B2_Weights",
24
+ "EfficientNet_B3_Weights",
25
+ "EfficientNet_B4_Weights",
26
+ "EfficientNet_B5_Weights",
27
+ "EfficientNet_B6_Weights",
28
+ "EfficientNet_B7_Weights",
29
+ "EfficientNet_V2_S_Weights",
30
+ "EfficientNet_V2_M_Weights",
31
+ "EfficientNet_V2_L_Weights",
32
+ "efficientnet_b0",
33
+ "efficientnet_b1",
34
+ "efficientnet_b2",
35
+ "efficientnet_b3",
36
+ "efficientnet_b4",
37
+ "efficientnet_b5",
38
+ "efficientnet_b6",
39
+ "efficientnet_b7",
40
+ "efficientnet_v2_s",
41
+ "efficientnet_v2_m",
42
+ "efficientnet_v2_l",
43
+ ]
44
+
45
+
46
+ @dataclass
47
+ class _MBConvConfig:
48
+ expand_ratio: float
49
+ kernel: int
50
+ stride: int
51
+ input_channels: int
52
+ out_channels: int
53
+ num_layers: int
54
+ block: Callable[..., nn.Module]
55
+
56
+ @staticmethod
57
+ def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
58
+ return _make_divisible(channels * width_mult, 8, min_value)
59
+
60
+
61
+ class MBConvConfig(_MBConvConfig):
62
+ # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper
63
+ def __init__(
64
+ self,
65
+ expand_ratio: float,
66
+ kernel: int,
67
+ stride: int,
68
+ input_channels: int,
69
+ out_channels: int,
70
+ num_layers: int,
71
+ width_mult: float = 1.0,
72
+ depth_mult: float = 1.0,
73
+ block: Optional[Callable[..., nn.Module]] = None,
74
+ ) -> None:
75
+ input_channels = self.adjust_channels(input_channels, width_mult)
76
+ out_channels = self.adjust_channels(out_channels, width_mult)
77
+ num_layers = self.adjust_depth(num_layers, depth_mult)
78
+ if block is None:
79
+ block = MBConv
80
+ super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
81
+
82
+ @staticmethod
83
+ def adjust_depth(num_layers: int, depth_mult: float):
84
+ return int(math.ceil(num_layers * depth_mult))
85
+
86
+
87
+ class FusedMBConvConfig(_MBConvConfig):
88
+ # Stores information listed at Table 4 of the EfficientNetV2 paper
89
+ def __init__(
90
+ self,
91
+ expand_ratio: float,
92
+ kernel: int,
93
+ stride: int,
94
+ input_channels: int,
95
+ out_channels: int,
96
+ num_layers: int,
97
+ block: Optional[Callable[..., nn.Module]] = None,
98
+ ) -> None:
99
+ if block is None:
100
+ block = FusedMBConv
101
+ super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block)
102
+
103
+
104
+ class MBConv(nn.Module):
105
+ def __init__(
106
+ self,
107
+ cnf: MBConvConfig,
108
+ stochastic_depth_prob: float,
109
+ norm_layer: Callable[..., nn.Module],
110
+ se_layer: Callable[..., nn.Module] = SqueezeExcitation,
111
+ ) -> None:
112
+ super().__init__()
113
+
114
+ if not (1 <= cnf.stride <= 2):
115
+ raise ValueError("illegal stride value")
116
+
117
+ self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
118
+
119
+ layers: List[nn.Module] = []
120
+ activation_layer = nn.SiLU
121
+
122
+ # expand
123
+ expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
124
+ if expanded_channels != cnf.input_channels:
125
+ layers.append(
126
+ Conv2dNormActivation(
127
+ cnf.input_channels,
128
+ expanded_channels,
129
+ kernel_size=1,
130
+ norm_layer=norm_layer,
131
+ activation_layer=activation_layer,
132
+ )
133
+ )
134
+
135
+ # depthwise
136
+ layers.append(
137
+ Conv2dNormActivation(
138
+ expanded_channels,
139
+ expanded_channels,
140
+ kernel_size=cnf.kernel,
141
+ stride=cnf.stride,
142
+ groups=expanded_channels,
143
+ norm_layer=norm_layer,
144
+ activation_layer=activation_layer,
145
+ )
146
+ )
147
+
148
+ # squeeze and excitation
149
+ squeeze_channels = max(1, cnf.input_channels // 4)
150
+ layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True)))
151
+
152
+ # project
153
+ layers.append(
154
+ Conv2dNormActivation(
155
+ expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
156
+ )
157
+ )
158
+
159
+ self.block = nn.Sequential(*layers)
160
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
161
+ self.out_channels = cnf.out_channels
162
+
163
+ def forward(self, input: Tensor) -> Tensor:
164
+ result = self.block(input)
165
+ if self.use_res_connect:
166
+ result = self.stochastic_depth(result)
167
+ result += input
168
+ return result
169
+
170
+
171
+ class FusedMBConv(nn.Module):
172
+ def __init__(
173
+ self,
174
+ cnf: FusedMBConvConfig,
175
+ stochastic_depth_prob: float,
176
+ norm_layer: Callable[..., nn.Module],
177
+ ) -> None:
178
+ super().__init__()
179
+
180
+ if not (1 <= cnf.stride <= 2):
181
+ raise ValueError("illegal stride value")
182
+
183
+ self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels
184
+
185
+ layers: List[nn.Module] = []
186
+ activation_layer = nn.SiLU
187
+
188
+ expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
189
+ if expanded_channels != cnf.input_channels:
190
+ # fused expand
191
+ layers.append(
192
+ Conv2dNormActivation(
193
+ cnf.input_channels,
194
+ expanded_channels,
195
+ kernel_size=cnf.kernel,
196
+ stride=cnf.stride,
197
+ norm_layer=norm_layer,
198
+ activation_layer=activation_layer,
199
+ )
200
+ )
201
+
202
+ # project
203
+ layers.append(
204
+ Conv2dNormActivation(
205
+ expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
206
+ )
207
+ )
208
+ else:
209
+ layers.append(
210
+ Conv2dNormActivation(
211
+ cnf.input_channels,
212
+ cnf.out_channels,
213
+ kernel_size=cnf.kernel,
214
+ stride=cnf.stride,
215
+ norm_layer=norm_layer,
216
+ activation_layer=activation_layer,
217
+ )
218
+ )
219
+
220
+ self.block = nn.Sequential(*layers)
221
+ self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
222
+ self.out_channels = cnf.out_channels
223
+
224
+ def forward(self, input: Tensor) -> Tensor:
225
+ result = self.block(input)
226
+ if self.use_res_connect:
227
+ result = self.stochastic_depth(result)
228
+ result += input
229
+ return result
230
+
231
+
232
+ class EfficientNet(nn.Module):
233
+ def __init__(
234
+ self,
235
+ inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
236
+ dropout: float,
237
+ stochastic_depth_prob: float = 0.2,
238
+ num_classes: int = 1000,
239
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
240
+ last_channel: Optional[int] = None,
241
+ ) -> None:
242
+ """
243
+ EfficientNet V1 and V2 main class
244
+
245
+ Args:
246
+ inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure
247
+ dropout (float): The droupout probability
248
+ stochastic_depth_prob (float): The stochastic depth probability
249
+ num_classes (int): Number of classes
250
+ norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
251
+ last_channel (int): The number of channels on the penultimate layer
252
+ """
253
+ super().__init__()
254
+ _log_api_usage_once(self)
255
+
256
+ if not inverted_residual_setting:
257
+ raise ValueError("The inverted_residual_setting should not be empty")
258
+ elif not (
259
+ isinstance(inverted_residual_setting, Sequence)
260
+ and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting])
261
+ ):
262
+ raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")
263
+
264
+ if norm_layer is None:
265
+ norm_layer = nn.BatchNorm2d
266
+
267
+ layers: List[nn.Module] = []
268
+
269
+ # building first layer
270
+ firstconv_output_channels = inverted_residual_setting[0].input_channels
271
+ layers.append(
272
+ Conv2dNormActivation(
273
+ 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
274
+ )
275
+ )
276
+
277
+ # building inverted residual blocks
278
+ total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
279
+ stage_block_id = 0
280
+ for cnf in inverted_residual_setting:
281
+ stage: List[nn.Module] = []
282
+ for _ in range(cnf.num_layers):
283
+ # copy to avoid modifications. shallow copy is enough
284
+ block_cnf = copy.copy(cnf)
285
+
286
+ # overwrite info if not the first conv in the stage
287
+ if stage:
288
+ block_cnf.input_channels = block_cnf.out_channels
289
+ block_cnf.stride = 1
290
+
291
+ # adjust stochastic depth probability based on the depth of the stage block
292
+ sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks
293
+
294
+ stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
295
+ stage_block_id += 1
296
+
297
+ layers.append(nn.Sequential(*stage))
298
+
299
+ # building last several layers
300
+ lastconv_input_channels = inverted_residual_setting[-1].out_channels
301
+ lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
302
+ layers.append(
303
+ Conv2dNormActivation(
304
+ lastconv_input_channels,
305
+ lastconv_output_channels,
306
+ kernel_size=1,
307
+ norm_layer=norm_layer,
308
+ activation_layer=nn.SiLU,
309
+ )
310
+ )
311
+
312
+ self.features = nn.Sequential(*layers)
313
+ self.avgpool = nn.AdaptiveAvgPool2d(1)
314
+ self.classifier = nn.Sequential(
315
+ nn.Dropout(p=dropout, inplace=True),
316
+ nn.Linear(lastconv_output_channels, num_classes),
317
+ )
318
+
319
+ for m in self.modules():
320
+ if isinstance(m, nn.Conv2d):
321
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
322
+ if m.bias is not None:
323
+ nn.init.zeros_(m.bias)
324
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
325
+ nn.init.ones_(m.weight)
326
+ nn.init.zeros_(m.bias)
327
+ elif isinstance(m, nn.Linear):
328
+ init_range = 1.0 / math.sqrt(m.out_features)
329
+ nn.init.uniform_(m.weight, -init_range, init_range)
330
+ nn.init.zeros_(m.bias)
331
+
332
+ def _forward_impl(self, x: Tensor) -> Tensor:
333
+ x = self.features(x)
334
+
335
+ x = self.avgpool(x)
336
+ x = torch.flatten(x, 1)
337
+
338
+ x = self.classifier(x)
339
+
340
+ return x
341
+
342
+ def forward(self, x: Tensor) -> Tensor:
343
+ return self._forward_impl(x)
344
+
345
+
346
+ def _efficientnet(
347
+ inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
348
+ dropout: float,
349
+ last_channel: Optional[int],
350
+ weights: Optional[WeightsEnum],
351
+ progress: bool,
352
+ **kwargs: Any,
353
+ ) -> EfficientNet:
354
+ if weights is not None:
355
+ _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
356
+
357
+ model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs)
358
+
359
+ if weights is not None:
360
+ model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
361
+
362
+ return model
363
+
364
+
365
def _efficientnet_conf(
    arch: str,
    **kwargs: Any,
) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]:
    """Return ``(block settings, last_channel)`` for an architecture name.

    V1 variants ("efficientnet_b*") share one template scaled by ``width_mult``
    and ``depth_mult`` (consumed from ``kwargs``); V2 variants are hard-coded.
    ``last_channel`` is ``None`` for V1 (the model default applies) and 1280
    for every V2 variant.
    """
    settings: Sequence[Union[MBConvConfig, FusedMBConvConfig]]
    if arch.startswith("efficientnet_b"):
        # One scaled template covers every V1 size (b0..b7).
        block = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult"))
        settings = [
            block(1, 3, 1, 32, 16, 1),
            block(6, 3, 2, 16, 24, 2),
            block(6, 5, 2, 24, 40, 2),
            block(6, 3, 2, 40, 80, 3),
            block(6, 5, 1, 80, 112, 3),
            block(6, 5, 2, 112, 192, 4),
            block(6, 3, 1, 192, 320, 1),
        ]
        head_channels = None
    elif arch.startswith("efficientnet_v2_s"):
        settings = [
            FusedMBConvConfig(1, 3, 1, 24, 24, 2),
            FusedMBConvConfig(4, 3, 2, 24, 48, 4),
            FusedMBConvConfig(4, 3, 2, 48, 64, 4),
            MBConvConfig(4, 3, 2, 64, 128, 6),
            MBConvConfig(6, 3, 1, 128, 160, 9),
            MBConvConfig(6, 3, 2, 160, 256, 15),
        ]
        head_channels = 1280
    elif arch.startswith("efficientnet_v2_m"):
        settings = [
            FusedMBConvConfig(1, 3, 1, 24, 24, 3),
            FusedMBConvConfig(4, 3, 2, 24, 48, 5),
            FusedMBConvConfig(4, 3, 2, 48, 80, 5),
            MBConvConfig(4, 3, 2, 80, 160, 7),
            MBConvConfig(6, 3, 1, 160, 176, 14),
            MBConvConfig(6, 3, 2, 176, 304, 18),
            MBConvConfig(6, 3, 1, 304, 512, 5),
        ]
        head_channels = 1280
    elif arch.startswith("efficientnet_v2_l"):
        settings = [
            FusedMBConvConfig(1, 3, 1, 32, 32, 4),
            FusedMBConvConfig(4, 3, 2, 32, 64, 7),
            FusedMBConvConfig(4, 3, 2, 64, 96, 7),
            MBConvConfig(4, 3, 2, 96, 192, 10),
            MBConvConfig(6, 3, 1, 192, 224, 19),
            MBConvConfig(6, 3, 2, 224, 384, 25),
            MBConvConfig(6, 3, 1, 384, 640, 7),
        ]
        head_channels = 1280
    else:
        raise ValueError(f"Unsupported model type {arch}")

    return settings, head_channels
418
+
419
+
420
# Metadata shared by every EfficientNet checkpoint defined below.
_COMMON_META: Dict[str, Any] = {
    "categories": _IMAGENET_CATEGORIES,
}
423
+
424
+
425
# Metadata shared by the EfficientNet V1 (b0..b7) checkpoints.
_COMMON_META_V1 = {
    **_COMMON_META,
    "min_size": (1, 1),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1",
}
430
+
431
+
432
# Metadata shared by the EfficientNet V2 (s/m/l) checkpoints.
_COMMON_META_V2 = {
    **_COMMON_META,
    "min_size": (33, 33),
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2",
}
437
+
438
+
439
class EfficientNet_B0_Weights(WeightsEnum):
    # Single checkpoint for EfficientNet-B0; metrics are ImageNet-1K validation accuracy.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 5288548,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 77.692,
                    "acc@5": 93.532,
                }
            },
            "_ops": 0.386,
            "_file_size": 20.451,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
461
+
462
+
463
class EfficientNet_B1_Weights(WeightsEnum):
    # Two checkpoints: V1 ported from timm, V2 retrained by TorchVision (higher accuracy).
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b1_rwightman-bac287d4.pth",
        transforms=partial(
            ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 7794184,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 78.642,
                    "acc@5": 94.186,
                }
            },
            "_ops": 0.687,
            "_file_size": 30.134,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
        transforms=partial(
            ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 7794184,
            # V2 overrides the shared recipe link with the improved-recipe issue.
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 79.838,
                    "acc@5": 94.934,
                }
            },
            "_ops": 0.687,
            "_file_size": 30.136,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2
509
+
510
+
511
class EfficientNet_B2_Weights(WeightsEnum):
    # Single checkpoint for EfficientNet-B2; metrics are ImageNet-1K validation accuracy.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth",
        transforms=partial(
            ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 9109994,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 80.608,
                    "acc@5": 95.310,
                }
            },
            "_ops": 1.088,
            "_file_size": 35.174,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
533
+
534
+
535
class EfficientNet_B3_Weights(WeightsEnum):
    # Single checkpoint for EfficientNet-B3; metrics are ImageNet-1K validation accuracy.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth",
        transforms=partial(
            ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 12233232,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 82.008,
                    "acc@5": 96.054,
                }
            },
            "_ops": 1.827,
            "_file_size": 47.184,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
557
+
558
+
559
class EfficientNet_B4_Weights(WeightsEnum):
    # Single checkpoint for EfficientNet-B4; metrics are ImageNet-1K validation accuracy.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/rwightman/pytorch-image-models/
        url="https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth",
        transforms=partial(
            ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 19341616,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.384,
                    "acc@5": 96.594,
                }
            },
            "_ops": 4.394,
            "_file_size": 74.489,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
581
+
582
+
583
class EfficientNet_B5_Weights(WeightsEnum):
    # Single checkpoint for EfficientNet-B5; ported from a different upstream than b0-b4.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-1a07897c.pth",
        transforms=partial(
            ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 30389784,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.444,
                    "acc@5": 96.628,
                }
            },
            "_ops": 10.266,
            "_file_size": 116.864,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
605
+
606
+
607
class EfficientNet_B6_Weights(WeightsEnum):
    # Single checkpoint for EfficientNet-B6; metrics are ImageNet-1K validation accuracy.
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-24a108a5.pth",
        transforms=partial(
            ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 43040704,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.008,
                    "acc@5": 96.916,
                }
            },
            "_ops": 19.068,
            "_file_size": 165.362,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
629
+
630
+
631
class EfficientNet_B7_Weights(WeightsEnum):
    # Single checkpoint for EfficientNet-B7 (largest V1 variant).
    IMAGENET1K_V1 = Weights(
        # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/
        url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-c5b4e57e.pth",
        transforms=partial(
            ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            **_COMMON_META_V1,
            "num_params": 66347960,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.122,
                    "acc@5": 96.908,
                }
            },
            "_ops": 37.746,
            "_file_size": 254.675,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
653
+
654
+
655
class EfficientNet_V2_S_Weights(WeightsEnum):
    # EfficientNetV2-S checkpoint trained with TorchVision's improved recipe.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
        transforms=partial(
            ImageClassification,
            crop_size=384,
            resize_size=384,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 21458488,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 84.228,
                    "acc@5": 96.878,
                }
            },
            "_ops": 8.366,
            "_file_size": 82.704,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
683
+
684
+
685
class EfficientNet_V2_M_Weights(WeightsEnum):
    # EfficientNetV2-M checkpoint trained with TorchVision's improved recipe.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
        transforms=partial(
            ImageClassification,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BILINEAR,
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 54139356,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.112,
                    "acc@5": 97.156,
                }
            },
            "_ops": 24.582,
            "_file_size": 208.01,
            "_docs": """
                These weights improve upon the results of the original paper by using a modified version of TorchVision's
                `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
713
+
714
+
715
class EfficientNet_V2_L_Weights(WeightsEnum):
    # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2
    # Note: unlike the other variants, this checkpoint expects 0.5/0.5 normalization.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
        transforms=partial(
            ImageClassification,
            crop_size=480,
            resize_size=480,
            interpolation=InterpolationMode.BICUBIC,
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
        meta={
            **_COMMON_META_V2,
            "num_params": 118515272,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 85.808,
                    "acc@5": 97.788,
                }
            },
            "_ops": 56.08,
            "_file_size": 454.573,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
742
+
743
+
744
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1))
def efficientnet_b0(
    *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B0_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_B0_Weights
        :members:
    """
    weights = EfficientNet_B0_Weights.verify(weights)

    # B0 is the unscaled baseline: width_mult=1.0, depth_mult=1.0.
    dropout = kwargs.pop("dropout", 0.2)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0)
    return _efficientnet(inverted_residual_setting, dropout, last_channel, weights, progress, **kwargs)
773
+
774
+
775
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1))
def efficientnet_b1(
    *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B1_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_B1_Weights
        :members:
    """
    weights = EfficientNet_B1_Weights.verify(weights)

    # B1 scaling coefficients: width_mult=1.0, depth_mult=1.1.
    dropout = kwargs.pop("dropout", 0.2)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1)
    return _efficientnet(inverted_residual_setting, dropout, last_channel, weights, progress, **kwargs)
804
+
805
+
806
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1))
def efficientnet_b2(
    *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B2_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_B2_Weights
        :members:
    """
    weights = EfficientNet_B2_Weights.verify(weights)

    # B2 scaling coefficients: width_mult=1.1, depth_mult=1.2.
    dropout = kwargs.pop("dropout", 0.3)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2)
    return _efficientnet(inverted_residual_setting, dropout, last_channel, weights, progress, **kwargs)
835
+
836
+
837
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1))
def efficientnet_b3(
    *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B3 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B3_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_B3_Weights
        :members:
    """
    weights = EfficientNet_B3_Weights.verify(weights)

    # B3 scaling coefficients: width_mult=1.2, depth_mult=1.4.
    dropout = kwargs.pop("dropout", 0.3)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4)
    return _efficientnet(inverted_residual_setting, dropout, last_channel, weights, progress, **kwargs)
871
+
872
+
873
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1))
def efficientnet_b4(
    *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B4 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B4_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_B4_Weights
        :members:
    """
    weights = EfficientNet_B4_Weights.verify(weights)

    # B4 scaling coefficients: width_mult=1.4, depth_mult=1.8.
    dropout = kwargs.pop("dropout", 0.4)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8)
    return _efficientnet(inverted_residual_setting, dropout, last_channel, weights, progress, **kwargs)
907
+
908
+
909
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1))
def efficientnet_b5(
    *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B5_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_B5_Weights
        :members:
    """
    weights = EfficientNet_B5_Weights.verify(weights)

    # B5 scaling coefficients: width_mult=1.6, depth_mult=2.2.
    dropout = kwargs.pop("dropout", 0.4)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2)
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        # b5-b7 use custom BatchNorm hyper-parameters to match the ported weights.
        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
        **kwargs,
    )
944
+
945
+
946
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1))
def efficientnet_b6(
    *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B6_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_B6_Weights
        :members:
    """
    weights = EfficientNet_B6_Weights.verify(weights)

    # B6 scaling coefficients: width_mult=1.8, depth_mult=2.6.
    dropout = kwargs.pop("dropout", 0.5)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6)
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        # b5-b7 use custom BatchNorm hyper-parameters to match the ported weights.
        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
        **kwargs,
    )
981
+
982
+
983
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1))
def efficientnet_b7(
    *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional
    Neural Networks <https://arxiv.org/abs/1905.11946>`_ paper.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_B7_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_B7_Weights
        :members:
    """
    weights = EfficientNet_B7_Weights.verify(weights)

    # B7 scaling coefficients: width_mult=2.0, depth_mult=3.1.
    dropout = kwargs.pop("dropout", 0.5)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1)
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        # b5-b7 use custom BatchNorm hyper-parameters to match the ported weights.
        norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01),
        **kwargs,
    )
1018
+
1019
+
1020
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1))
def efficientnet_v2_s(
    *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-S architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights
        :members:
    """
    weights = EfficientNet_V2_S_Weights.verify(weights)

    # V2 settings are fixed per variant; no width/depth multipliers.
    dropout = kwargs.pop("dropout", 0.2)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s")
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        # All V2 variants use BatchNorm with eps=1e-3 to match the reference training setup.
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )
1056
+
1057
+
1058
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1))
def efficientnet_v2_m(
    *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-M architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights
        :members:
    """
    weights = EfficientNet_V2_M_Weights.verify(weights)

    # V2 settings are fixed per variant; no width/depth multipliers.
    dropout = kwargs.pop("dropout", 0.3)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m")
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        # All V2 variants use BatchNorm with eps=1e-3 to match the reference training setup.
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )
1094
+
1095
+
1096
@register_model()
@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1))
def efficientnet_v2_l(
    *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any
) -> EfficientNet:
    """
    Constructs an EfficientNetV2-L architecture from
    `EfficientNetV2: Smaller Models and Faster Training <https://arxiv.org/abs/2104.00298>`_.

    Args:
        weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/efficientnet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights
        :members:
    """
    weights = EfficientNet_V2_L_Weights.verify(weights)

    # V2 settings are fixed per variant; no width/depth multipliers.
    dropout = kwargs.pop("dropout", 0.4)
    inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l")
    return _efficientnet(
        inverted_residual_setting,
        dropout,
        last_channel,
        weights,
        progress,
        # All V2 variants use BatchNorm with eps=1e-3 to match the reference training setup.
        norm_layer=partial(nn.BatchNorm2d, eps=1e-03),
        **kwargs,
    )
.venv/lib/python3.11/site-packages/torchvision/models/googlenet.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from collections import namedtuple
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Tuple
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch import Tensor
10
+
11
+ from ..transforms._presets import ImageClassification
12
+ from ..utils import _log_api_usage_once
13
+ from ._api import register_model, Weights, WeightsEnum
14
+ from ._meta import _IMAGENET_CATEGORIES
15
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
16
+
17
+
18
+ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weights", "googlenet"]
19
+
20
+
21
# Output type returned by GoogLeNet.forward during training with aux heads enabled:
# the main logits plus the two auxiliary-classifier logits.
GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"])
# Field annotations let TorchScript know the auxiliary entries may be None.
GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _GoogLeNetOutputs set here for backwards compat
_GoogLeNetOutputs = GoogLeNetOutputs
27
+
28
+
29
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) classification network from
    `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`_,
    with optional auxiliary classifier heads.

    Args:
        num_classes (int): number of output classes.
        aux_logits (bool): if True, build the two auxiliary heads (they are
            only evaluated in training mode).
        transform_input (bool): if True, re-normalize inputs from standard
            ImageNet normalization to the normalization the original weights expect.
        init_weights (Optional[bool]): if None, warns and then initializes
            (current default behavior is to initialize).
        blocks (Optional[List[Callable[..., nn.Module]]]): the three building
            blocks used: [conv block, inception block, inception aux block].
        dropout (float): dropout probability before the final classifier.
        dropout_aux (float): dropout probability inside the auxiliary heads.
    """

    # TorchScript treats these attribute values as compile-time constants.
    __constants__ = ["aux_logits", "transform_input"]

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        init_weights: Optional[bool] = None,
        blocks: Optional[List[Callable[..., nn.Module]]] = None,
        dropout: float = 0.2,
        dropout_aux: float = 0.7,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        if blocks is None:
            blocks = [BasicConv2d, Inception, InceptionAux]
        if init_weights is None:
            warnings.warn(
                "The default weight initialization of GoogleNet will be changed in future releases of "
                "torchvision. If you wish to keep the old behavior (which leads to long initialization times"
                " due to scipy/scipy#11299), please set init_weights=True.",
                FutureWarning,
            )
            init_weights = True
        if len(blocks) != 3:
            raise ValueError(f"blocks length should be 3 instead of {len(blocks)}")
        conv_block = blocks[0]
        inception_block = blocks[1]
        inception_aux_block = blocks[2]

        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem: 224x224x3 -> 28x28x192 before the first inception stage.
        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = conv_block(64, 64, kernel_size=1)
        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Inception stages; the positional args are the per-branch channel widths.
        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:
            # aux1 taps the 512-channel inception4a output, aux2 the 528-channel 4d output.
            self.aux1 = inception_aux_block(512, num_classes, dropout=dropout_aux)
            self.aux2 = inception_aux_block(528, num_classes, dropout=dropout_aux)
        else:
            self.aux1 = None  # type: ignore[assignment]
            self.aux2 = None  # type: ignore[assignment]

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=0.01, a=-2, b=2)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        # Re-normalize each channel from ImageNet mean/std to mean 0.5 / std 0.5,
        # the normalization the original (ported) weights were trained with.
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        # First auxiliary head: only evaluated in training mode.
        aux1: Optional[Tensor] = None
        if self.aux1 is not None:
            if self.training:
                aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        # Second auxiliary head: only evaluated in training mode.
        aux2: Optional[Tensor] = None
        if self.aux2 is not None:
            if self.training:
                aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux2, aux1

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:
        # In eval mode (or with aux heads disabled) callers get a plain Tensor.
        if self.training and self.aux_logits:
            return _GoogLeNetOutputs(x, aux2, aux1)
        else:
            return x  # type: ignore[return-value]

    def forward(self, x: Tensor) -> GoogLeNetOutputs:
        x = self._transform_input(x)
        # NOTE(review): _forward returns (x, aux2, aux1), but the unpacking below
        # names the values (aux1, aux2) and then swaps them again when building the
        # output — so the local names are crossed relative to _forward. Verify the
        # aux_logits1/aux_logits2 field ordering before relying on it.
        x, aux1, aux2 = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")
            return GoogLeNetOutputs(x, aux2, aux1)
        else:
            return self.eager_outputs(x, aux2, aux1)
183
+
184
class Inception(nn.Module):
    """GoogLeNet inception module: four parallel branches whose outputs are
    concatenated along the channel dimension.

    Args:
        in_channels: channels of the input feature map.
        ch1x1: output channels of the 1x1 branch.
        ch3x3red / ch3x3: reduction and output channels of the 3x3 branch.
        ch5x5red / ch5x5: reduction and output channels of the "5x5" branch.
        pool_proj: output channels of the pooling-projection branch.
        conv_block: conv building block; defaults to BasicConv2d.
    """

    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        conv = conv_block if conv_block is not None else BasicConv2d

        # Branch 1: plain 1x1 convolution.
        self.branch1 = conv(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction followed by a 3x3 convolution.
        self.branch2 = nn.Sequential(
            conv(in_channels, ch3x3red, kernel_size=1),
            conv(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )

        # Branch 3: 1x1 reduction followed by a second 3x3 convolution.
        # kernel_size=3 instead of kernel_size=5 is a known bug.
        # Please see https://github.com/pytorch/vision/issues/906 for details.
        self.branch3 = nn.Sequential(
            conv(in_channels, ch5x5red, kernel_size=1),
            conv(ch5x5red, ch5x5, kernel_size=3, padding=1),
        )

        # Branch 4: 3x3 max-pool followed by a 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            conv(in_channels, pool_proj, kernel_size=1),
        )

    def _forward(self, x: Tensor) -> List[Tensor]:
        # Evaluate all four branches on the same input.
        return [branch(x) for branch in (self.branch1, self.branch2, self.branch3, self.branch4)]

    def forward(self, x: Tensor) -> Tensor:
        # Concatenate branch outputs along the channel dimension.
        return torch.cat(self._forward(x), 1)
229
+
230
+
231
class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to intermediate GoogLeNet features.

    Args:
        in_channels: channels of the tapped feature map.
        num_classes: number of output classes.
        conv_block: conv building block; defaults to BasicConv2d.
        dropout: dropout probability between the two linear layers.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        conv_block: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.7,
    ) -> None:
        super().__init__()
        conv = BasicConv2d if conv_block is None else conv_block
        self.conv = conv(in_channels, 128, kernel_size=1)

        # 2048 = 128 channels * 4 * 4 spatial positions after the adaptive pool below.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: Tensor) -> Tensor:
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        out = F.adaptive_avg_pool2d(x, (4, 4))
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        out = self.conv(out)
        # N x 128 x 4 x 4
        out = torch.flatten(out, 1)
        # N x 2048
        out = F.relu(self.fc1(out), inplace=True)
        # N x 1024
        out = self.dropout(out)
        # N x num_classes
        return self.fc2(out)
264
+
265
+
266
class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU, the basic GoogLeNet conv unit."""

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super().__init__()
        # Bias is omitted because the following BatchNorm provides its own shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        normed = self.bn(self.conv(x))
        return F.relu(normed, inplace=True)
276
+
277
+
278
class GoogLeNet_Weights(WeightsEnum):
    """Pretrained weight enum for :class:`GoogLeNet`."""

    # Weights ported from the original paper's release (see "_docs" below).
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/googlenet-1378be20.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            "num_params": 6624904,
            "min_size": (15, 15),
            "categories": _IMAGENET_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#googlenet",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.778,
                    "acc@5": 89.530,
                }
            },
            "_ops": 1.498,
            "_file_size": 49.731,
            "_docs": """These weights are ported from the original paper.""",
        },
    )
    # Alias used when callers pass weights="DEFAULT".
    DEFAULT = IMAGENET1K_V1
299
+
300
+
301
@register_model()
@handle_legacy_interface(weights=("pretrained", GoogLeNet_Weights.IMAGENET1K_V1))
def googlenet(*, weights: Optional[GoogLeNet_Weights] = None, progress: bool = True, **kwargs: Any) -> GoogLeNet:
    """GoogLeNet (Inception v1) model architecture from
    `Going Deeper with Convolutions <http://arxiv.org/abs/1409.4842>`_.

    Args:
        weights (:class:`~torchvision.models.GoogLeNet_Weights`, optional): The
            pretrained weights for the model. See
            :class:`~torchvision.models.GoogLeNet_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.GoogLeNet``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/googlenet.py>`_
            for more details about this class.
    .. autoclass:: torchvision.models.GoogLeNet_Weights
        :members:
    """
    weights = GoogLeNet_Weights.verify(weights)

    # Remember whether the caller actually asked for aux heads; the checkpoint
    # needs them to load, but they may be stripped again afterwards.
    wants_aux = kwargs.get("aux_logits", False)
    if weights is not None:
        if "transform_input" not in kwargs:
            _ovewrite_named_param(kwargs, "transform_input", True)
        _ovewrite_named_param(kwargs, "aux_logits", True)
        _ovewrite_named_param(kwargs, "init_weights", False)
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = GoogLeNet(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
        if wants_aux:
            warnings.warn(
                "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them"
            )
        else:
            # Drop the auxiliary heads the checkpoint forced us to build.
            model.aux_logits = False
            model.aux1 = None  # type: ignore[assignment]
            model.aux2 = None  # type: ignore[assignment]

    return model
.venv/lib/python3.11/site-packages/torchvision/models/maxvit.py ADDED
@@ -0,0 +1,833 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections import OrderedDict
3
+ from functools import partial
4
+ from typing import Any, Callable, List, Optional, Sequence, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn, Tensor
10
+ from torchvision.models._api import register_model, Weights, WeightsEnum
11
+ from torchvision.models._meta import _IMAGENET_CATEGORIES
12
+ from torchvision.models._utils import _ovewrite_named_param, handle_legacy_interface
13
+ from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation
14
+ from torchvision.ops.stochastic_depth import StochasticDepth
15
+ from torchvision.transforms._presets import ImageClassification, InterpolationMode
16
+ from torchvision.utils import _log_api_usage_once
17
+
18
+ __all__ = [
19
+ "MaxVit",
20
+ "MaxVit_T_Weights",
21
+ "maxvit_t",
22
+ ]
23
+
24
+
25
+ def _get_conv_output_shape(input_size: Tuple[int, int], kernel_size: int, stride: int, padding: int) -> Tuple[int, int]:
26
+ return (
27
+ (input_size[0] - kernel_size + 2 * padding) // stride + 1,
28
+ (input_size[1] - kernel_size + 2 * padding) // stride + 1,
29
+ )
30
+
31
+
32
def _make_block_input_shapes(input_size: Tuple[int, int], n_blocks: int) -> List[Tuple[int, int]]:
    """Util function to check that the input size is correct for a MaxVit configuration."""
    # The stem downsamples once (3x3, stride 2, pad 1); each block halves again.
    shape = _get_conv_output_shape(input_size, 3, 2, 1)
    shapes: List[Tuple[int, int]] = []
    for _ in range(n_blocks):
        shape = _get_conv_output_shape(shape, 3, 2, 1)
        shapes.append(shape)
    return shapes
40
+
41
+
42
+ def _get_relative_position_index(height: int, width: int) -> torch.Tensor:
43
+ coords = torch.stack(torch.meshgrid([torch.arange(height), torch.arange(width)]))
44
+ coords_flat = torch.flatten(coords, 1)
45
+ relative_coords = coords_flat[:, :, None] - coords_flat[:, None, :]
46
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
47
+ relative_coords[:, :, 0] += height - 1
48
+ relative_coords[:, :, 1] += width - 1
49
+ relative_coords[:, :, 0] *= 2 * width - 1
50
+ return relative_coords.sum(-1)
51
+
52
+
53
class MBConv(nn.Module):
    """MBConv: Mobile Inverted Residual Bottleneck.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        stride (int): Stride of the depthwise convolution.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        p_stochastic_dropout (float): Probability of stochastic depth.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion_ratio: float,
        squeeze_ratio: float,
        stride: int,
        activation_layer: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        p_stochastic_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        proj: Sequence[nn.Module]
        self.proj: nn.Module

        # The residual path needs a projection whenever the spatial size or the
        # channel count changes on the main path.
        should_proj = stride != 1 or in_channels != out_channels
        if should_proj:
            proj = [nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=True)]
            if stride == 2:
                # Downsample with average pooling before the 1x1 projection.
                proj = [nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)] + proj  # type: ignore
            self.proj = nn.Sequential(*proj)
        else:
            self.proj = nn.Identity()  # type: ignore

        # NOTE(review): both widths are derived from out_channels (not in_channels) —
        # confirm against the reference MaxViT implementation if modifying.
        mid_channels = int(out_channels * expansion_ratio)
        sqz_channels = int(out_channels * squeeze_ratio)

        if p_stochastic_dropout:
            self.stochastic_depth = StochasticDepth(p_stochastic_dropout, mode="row")  # type: ignore
        else:
            self.stochastic_depth = nn.Identity()  # type: ignore

        # Main path: pre-norm -> 1x1 expand -> 3x3 depthwise -> SE -> 1x1 project.
        _layers = OrderedDict()
        _layers["pre_norm"] = norm_layer(in_channels)
        _layers["conv_a"] = Conv2dNormActivation(
            in_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            inplace=None,
        )
        _layers["conv_b"] = Conv2dNormActivation(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            groups=mid_channels,  # groups == channels -> depthwise convolution
            inplace=None,
        )
        _layers["squeeze_excitation"] = SqueezeExcitation(mid_channels, sqz_channels, activation=nn.SiLU)
        _layers["conv_c"] = nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1, bias=True)

        self.layers = nn.Sequential(_layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H / stride, W / stride].
        """
        # Residual connection with stochastic depth applied to the main path only.
        res = self.proj(x)
        x = self.stochastic_depth(self.layers(x))
        return res + x
138
+
139
+
140
class RelativePositionalMultiHeadAttention(nn.Module):
    """Relative Positional Multi-Head Attention.

    Multi-head self-attention over flattened window/grid partitions with a
    learned relative positional bias added to the attention logits.

    Args:
        feat_dim (int): Number of input features.
        head_dim (int): Number of features per head.
        max_seq_len (int): Maximum sequence length (must be a perfect square —
            it is the flattened P x P partition).
    """

    def __init__(
        self,
        feat_dim: int,
        head_dim: int,
        max_seq_len: int,
    ) -> None:
        super().__init__()

        if feat_dim % head_dim != 0:
            raise ValueError(f"feat_dim: {feat_dim} must be divisible by head_dim: {head_dim}")

        self.n_heads = feat_dim // head_dim
        self.head_dim = head_dim
        # Side length of the square partition the sequence was flattened from.
        self.size = int(math.sqrt(max_seq_len))
        self.max_seq_len = max_seq_len

        # Single projection produces q, k and v concatenated along the last dim.
        self.to_qkv = nn.Linear(feat_dim, self.n_heads * self.head_dim * 3)
        self.scale_factor = feat_dim**-0.5

        self.merge = nn.Linear(self.head_dim * self.n_heads, feat_dim)
        # One learnable bias per (relative offset, head) pair.
        self.relative_position_bias_table = nn.parameter.Parameter(
            torch.empty(((2 * self.size - 1) * (2 * self.size - 1), self.n_heads), dtype=torch.float32),
        )

        self.register_buffer("relative_position_index", _get_relative_position_index(self.size, self.size))
        # initialize with truncated normal the bias
        torch.nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def get_relative_positional_bias(self) -> torch.Tensor:
        # Gather the per-offset biases into a (1, heads, seq, seq) tensor that
        # broadcasts over batch and partition dimensions.
        bias_index = self.relative_position_index.view(-1)  # type: ignore
        relative_bias = self.relative_position_bias_table[bias_index].view(self.max_seq_len, self.max_seq_len, -1)  # type: ignore
        relative_bias = relative_bias.permute(2, 0, 1).contiguous()
        return relative_bias.unsqueeze(0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, G, P, D].
        Returns:
            Tensor: Output tensor with expected layout of [B, G, P, D].
        """
        B, G, P, D = x.shape
        H, DH = self.n_heads, self.head_dim

        qkv = self.to_qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        # Reshape to [B, G, H, P, DH] so attention runs per head within each partition.
        q = q.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        k = k.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)
        v = v.reshape(B, G, P, H, DH).permute(0, 1, 3, 2, 4)

        # NOTE(review): scale uses feat_dim**-0.5 (set in __init__), not the more
        # common head_dim**-0.5 — confirm against the reference before changing.
        k = k * self.scale_factor
        dot_prod = torch.einsum("B G H I D, B G H J D -> B G H I J", q, k)
        pos_bias = self.get_relative_positional_bias()

        dot_prod = F.softmax(dot_prod + pos_bias, dim=-1)

        out = torch.einsum("B G H I J, B G H J D -> B G H I D", dot_prod, v)
        # Merge heads back into the feature dimension.
        out = out.permute(0, 1, 3, 2, 4).reshape(B, G, P, D)

        out = self.merge(out)
        return out
211
+
212
+
213
class SwapAxes(nn.Module):
    """Swap two axes of the input tensor."""

    def __init__(self, a: int, b: int) -> None:
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.swapaxes(x, self.a, self.b)
224
+
225
+
226
class WindowPartition(nn.Module):
    """
    Partition the input tensor into non-overlapping windows.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, p: int) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
            p (int): Number of partitions.
        Returns:
            Tensor: Output tensor with expected layout of [B, (H/P)*(W/P), P*P, C].
        """
        batch, channels, height, width = x.shape
        n_h, n_w = height // p, width // p
        # Split H and W into (window index, offset-within-window) pairs.
        windows = x.reshape(batch, channels, n_h, p, n_w, p)
        # Bring window indices forward and channels last.
        windows = windows.permute(0, 2, 4, 3, 5, 1)
        # Collapse the window grid and each P*P window.
        return windows.reshape(batch, n_h * n_w, p * p, channels)
250
+
251
+
252
class WindowDepartition(nn.Module):
    """
    Inverse of WindowPartition: reassemble non-overlapping windows into a
    feature volume of layout [B, C, H, W].
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor, p: int, h_partitions: int, w_partitions: int) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, (H/P * W/P), P*P, C].
            p (int): Number of partitions.
            h_partitions (int): Number of vertical partitions.
            w_partitions (int): Number of horizontal partitions.
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H, W].
        """
        batch, _, _, channels = x.shape
        hp, wp = h_partitions, w_partitions
        # Recover (window row, window col, in-window row, in-window col) axes.
        vol = x.reshape(batch, hp, wp, p, p, channels)
        # Reorder to [B, C, HP, P, WP, P] so H and W can be merged back.
        vol = vol.permute(0, 5, 1, 3, 2, 4)
        return vol.reshape(batch, channels, hp * p, wp * p)
280
+
281
+
282
class PartitionAttentionLayer(nn.Module):
    """
    Layer for partitioning the input tensor into non-overlapping windows and applying attention to each window.

    Args:
        in_channels (int): Number of input channels.
        head_dim (int): Dimension of each attention head.
        partition_size (int): Size of the partitions.
        partition_type (str): Type of partitioning to use. Can be either "grid" or "window".
        grid_size (Tuple[int, int]): Size of the grid to partition the input tensor into.
        mlp_ratio (int): Ratio of the feature size expansion in the MLP layer.
        activation_layer (Callable[..., nn.Module]): Activation function to use.
        norm_layer (Callable[..., nn.Module]): Normalization function to use.
        attention_dropout (float): Dropout probability for the attention layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        p_stochastic_dropout (float): Probability of dropping out a partition.
    """

    def __init__(
        self,
        in_channels: int,
        head_dim: int,
        # partitioning parameters
        partition_size: int,
        partition_type: str,
        # grid size needs to be known at initialization time
        # because we need to know how many relative offsets there are in the grid
        grid_size: Tuple[int, int],
        mlp_ratio: int,
        activation_layer: Callable[..., nn.Module],
        norm_layer: Callable[..., nn.Module],
        attention_dropout: float,
        mlp_dropout: float,
        p_stochastic_dropout: float,
    ) -> None:
        super().__init__()

        self.n_heads = in_channels // head_dim
        self.head_dim = head_dim
        self.n_partitions = grid_size[0] // partition_size
        self.partition_type = partition_type
        self.grid_size = grid_size

        if partition_type not in ["grid", "window"]:
            raise ValueError("partition_type must be either 'grid' or 'window'")

        # "window" attends within local P x P tiles; "grid" swaps the roles so
        # attention spans dilated positions across the whole feature map.
        if partition_type == "window":
            self.p, self.g = partition_size, self.n_partitions
        else:
            self.p, self.g = self.n_partitions, partition_size

        self.partition_op = WindowPartition()
        self.departition_op = WindowDepartition()
        # For grid partitioning the tile and window axes are swapped around attention.
        self.partition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()
        self.departition_swap = SwapAxes(-2, -3) if partition_type == "grid" else nn.Identity()

        self.attn_layer = nn.Sequential(
            norm_layer(in_channels),
            # it's always going to be partition_size ** 2 because
            # of the axis swap in the case of grid partitioning
            RelativePositionalMultiHeadAttention(in_channels, head_dim, partition_size**2),
            nn.Dropout(attention_dropout),
        )

        # pre-normalization similar to transformer layers
        self.mlp_layer = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, in_channels * mlp_ratio),
            activation_layer(),
            nn.Linear(in_channels * mlp_ratio, in_channels),
            nn.Dropout(mlp_dropout),
        )

        # layer scale factors
        self.stochastic_dropout = StochasticDepth(p_stochastic_dropout, mode="row")

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor with expected layout of [B, C, H, W].
        Returns:
            Tensor: Output tensor with expected layout of [B, C, H, W].
        """

        # Undefined behavior if H or W are not divisible by p
        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
        gh, gw = self.grid_size[0] // self.p, self.grid_size[1] // self.p
        torch._assert(
            self.grid_size[0] % self.p == 0 and self.grid_size[1] % self.p == 0,
            "Grid size must be divisible by partition size. Got grid size of {} and partition size of {}".format(
                self.grid_size, self.p
            ),
        )

        # Partition -> (optional axis swap) -> attention + MLP residuals -> undo.
        x = self.partition_op(x, self.p)
        x = self.partition_swap(x)
        x = x + self.stochastic_dropout(self.attn_layer(x))
        x = x + self.stochastic_dropout(self.mlp_layer(x))
        x = self.departition_swap(x)
        x = self.departition_op(x, self.p, gh, gw)

        return x
384
+
385
+
386
class MaxVitLayer(nn.Module):
    """
    MaxVit layer consisting of a MBConv layer followed by a PartitionAttentionLayer with `window` and a PartitionAttentionLayer with `grid`.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        stride (int): Stride of the depthwise convolution.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Ratio of the MLP layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        attention_dropout (float): Dropout probability for the attention layer.
        p_stochastic_dropout (float): Probability of stochastic depth.
        partition_size (int): Size of the partitions.
        grid_size (Tuple[int, int]): Size of the input feature grid.
    """

    def __init__(
        self,
        # conv parameters
        in_channels: int,
        out_channels: int,
        squeeze_ratio: float,
        expansion_ratio: float,
        stride: int,
        # conv + transformer parameters
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        # transformer parameters
        head_dim: int,
        mlp_ratio: int,
        mlp_dropout: float,
        attention_dropout: float,
        p_stochastic_dropout: float,
        # partitioning parameters
        partition_size: int,
        grid_size: Tuple[int, int],
    ) -> None:
        super().__init__()

        layers: OrderedDict = OrderedDict()

        # convolutional layer (handles any stride/channel change)
        layers["MBconv"] = MBConv(
            in_channels=in_channels,
            out_channels=out_channels,
            expansion_ratio=expansion_ratio,
            squeeze_ratio=squeeze_ratio,
            stride=stride,
            activation_layer=activation_layer,
            norm_layer=norm_layer,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        # attention layers, block -> grid: local window attention first,
        # then sparse/dilated grid attention, both at out_channels width.
        layers["window_attention"] = PartitionAttentionLayer(
            in_channels=out_channels,
            head_dim=head_dim,
            partition_size=partition_size,
            partition_type="window",
            grid_size=grid_size,
            mlp_ratio=mlp_ratio,
            activation_layer=activation_layer,
            norm_layer=nn.LayerNorm,
            attention_dropout=attention_dropout,
            mlp_dropout=mlp_dropout,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        layers["grid_attention"] = PartitionAttentionLayer(
            in_channels=out_channels,
            head_dim=head_dim,
            partition_size=partition_size,
            partition_type="grid",
            grid_size=grid_size,
            mlp_ratio=mlp_ratio,
            activation_layer=activation_layer,
            norm_layer=nn.LayerNorm,
            attention_dropout=attention_dropout,
            mlp_dropout=mlp_dropout,
            p_stochastic_dropout=p_stochastic_dropout,
        )
        self.layers = nn.Sequential(layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).
        Returns:
            Tensor: Output tensor of shape (B, C, H, W).
        """
        x = self.layers(x)
        return x
481
+
482
+
483
class MaxVitBlock(nn.Module):
    """
    A MaxVit block consisting of `n_layers` MaxVit layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        expansion_ratio (float): Expansion ratio in the bottleneck.
        squeeze_ratio (float): Squeeze ratio in the SE Layer.
        activation_layer (Callable[..., nn.Module]): Activation function.
        norm_layer (Callable[..., nn.Module]): Normalization function.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Ratio of the MLP layer.
        mlp_dropout (float): Dropout probability for the MLP layer.
        attention_dropout (float): Dropout probability for the attention layer.
        partition_size (int): Size of the partitions.
        input_grid_size (Tuple[int, int]): Size of the input feature grid.
        n_layers (int): Number of layers in the block.
        p_stochastic (List[float]): List of probabilities for stochastic depth for each layer.
    """

    def __init__(
        self,
        # conv parameters
        in_channels: int,
        out_channels: int,
        squeeze_ratio: float,
        expansion_ratio: float,
        # conv + transformer parameters
        norm_layer: Callable[..., nn.Module],
        activation_layer: Callable[..., nn.Module],
        # transformer parameters
        head_dim: int,
        mlp_ratio: int,
        mlp_dropout: float,
        attention_dropout: float,
        # partitioning parameters
        partition_size: int,
        input_grid_size: Tuple[int, int],
        # number of layers
        n_layers: int,
        p_stochastic: List[float],
    ) -> None:
        super().__init__()
        if not len(p_stochastic) == n_layers:
            raise ValueError(f"p_stochastic must have length n_layers={n_layers}, got p_stochastic={p_stochastic}.")

        self.layers = nn.ModuleList()
        # account for the first stride of the first layer
        self.grid_size = _get_conv_output_shape(input_grid_size, kernel_size=3, stride=2, padding=1)

        # Only the first layer downsamples (stride=2) and changes the channel
        # count; every following layer keeps the block's output resolution.
        for idx, p in enumerate(p_stochastic):
            stride = 2 if idx == 0 else 1
            self.layers += [
                MaxVitLayer(
                    in_channels=in_channels if idx == 0 else out_channels,
                    out_channels=out_channels,
                    squeeze_ratio=squeeze_ratio,
                    expansion_ratio=expansion_ratio,
                    stride=stride,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    mlp_dropout=mlp_dropout,
                    attention_dropout=attention_dropout,
                    partition_size=partition_size,
                    grid_size=self.grid_size,
                    p_stochastic_dropout=p,  # per-layer stochastic depth probability
                ),
            ]

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).
        Returns:
            Tensor: Output tensor of shape (B, C, H, W).
        """
        for layer in self.layers:
            x = layer(x)
        return x
566
+
567
+
568
class MaxVit(nn.Module):
    """
    Implements MaxVit Transformer from the `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_ paper.
    Args:
        input_size (Tuple[int, int]): Size of the input image.
        stem_channels (int): Number of channels in the stem.
        partition_size (int): Size of the partitions.
        block_channels (List[int]): Number of channels in each block.
        block_layers (List[int]): Number of layers in each block.
        stochastic_depth_prob (float): Probability of stochastic depth. Expands to a list of probabilities for each layer that scales linearly to the specified value.
        squeeze_ratio (float): Squeeze ratio in the SE Layer. Default: 0.25.
        expansion_ratio (float): Expansion ratio in the MBConv bottleneck. Default: 4.
        norm_layer (Callable[..., nn.Module]): Normalization function. Default: None (setting to None will produce a `BatchNorm2d(eps=1e-3, momentum=0.01)`).
        activation_layer (Callable[..., nn.Module]): Activation function Default: nn.GELU.
        head_dim (int): Dimension of the attention heads.
        mlp_ratio (int): Expansion ratio of the MLP layer. Default: 4.
        mlp_dropout (float): Dropout probability for the MLP layer. Default: 0.0.
        attention_dropout (float): Dropout probability for the attention layer. Default: 0.0.
        num_classes (int): Number of classes. Default: 1000.
    """

    def __init__(
        self,
        # input size parameters
        input_size: Tuple[int, int],
        # stem and task parameters
        stem_channels: int,
        # partitioning parameters
        partition_size: int,
        # block parameters
        block_channels: List[int],
        block_layers: List[int],
        # attention head dimensions
        head_dim: int,
        stochastic_depth_prob: float,
        # conv + transformer parameters
        # norm_layer is applied only to the conv layers
        # activation_layer is applied both to conv and transformer layers
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        activation_layer: Callable[..., nn.Module] = nn.GELU,
        # conv parameters
        squeeze_ratio: float = 0.25,
        expansion_ratio: float = 4,
        # transformer parameters
        mlp_ratio: int = 4,
        mlp_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        # task parameters
        num_classes: int = 1000,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        input_channels = 3  # RGB images

        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1029-L1030
        # for the exact parameters used in batchnorm
        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.01)

        # Make sure input size will be divisible by the partition size in all blocks
        # Undefined behavior if H or W are not divisible by p
        # https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L766
        block_input_sizes = _make_block_input_shapes(input_size, len(block_channels))
        for idx, block_input_size in enumerate(block_input_sizes):
            if block_input_size[0] % partition_size != 0 or block_input_size[1] % partition_size != 0:
                raise ValueError(
                    f"Input size {block_input_size} of block {idx} is not divisible by partition size {partition_size}. "
                    f"Consider changing the partition size or the input size.\n"
                    f"Current configuration yields the following block input sizes: {block_input_sizes}."
                )

        # stem: two 3x3 convs, the first with stride 2, the second linear (no norm/act)
        self.stem = nn.Sequential(
            Conv2dNormActivation(
                input_channels,
                stem_channels,
                3,
                stride=2,
                norm_layer=norm_layer,
                activation_layer=activation_layer,
                bias=False,
                inplace=None,
            ),
            Conv2dNormActivation(
                stem_channels, stem_channels, 3, stride=1, norm_layer=None, activation_layer=None, bias=True
            ),
        )

        # account for stem stride
        input_size = _get_conv_output_shape(input_size, kernel_size=3, stride=2, padding=1)
        self.partition_size = partition_size

        # blocks: each consumes the previous block's output channels
        self.blocks = nn.ModuleList()
        in_channels = [stem_channels] + block_channels[:-1]
        out_channels = block_channels

        # precompute the stochastich depth probabilities from 0 to stochastic_depth_prob
        # since we have N blocks with L layers, we will have N * L probabilities uniformly distributed
        # over the range [0, stochastic_depth_prob]
        p_stochastic = np.linspace(0, stochastic_depth_prob, sum(block_layers)).tolist()

        p_idx = 0
        for in_channel, out_channel, num_layers in zip(in_channels, out_channels, block_layers):
            self.blocks.append(
                MaxVitBlock(
                    in_channels=in_channel,
                    out_channels=out_channel,
                    squeeze_ratio=squeeze_ratio,
                    expansion_ratio=expansion_ratio,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer,
                    head_dim=head_dim,
                    mlp_ratio=mlp_ratio,
                    mlp_dropout=mlp_dropout,
                    attention_dropout=attention_dropout,
                    partition_size=partition_size,
                    input_grid_size=input_size,
                    n_layers=num_layers,
                    p_stochastic=p_stochastic[p_idx : p_idx + num_layers],
                ),
            )
            # the block downsamples (stride-2 first layer); feed its output grid forward
            input_size = self.blocks[-1].grid_size
            p_idx += num_layers

        # see https://github.com/google-research/maxvit/blob/da76cf0d8a6ec668cc31b399c4126186da7da944/maxvit/models/maxvit.py#L1137-L1158
        # for why there is Linear -> Tanh -> Linear
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.LayerNorm(block_channels[-1]),
            nn.Linear(block_channels[-1], block_channels[-1]),
            nn.Tanh(),
            nn.Linear(block_channels[-1], num_classes, bias=False),
        )

        self._init_weights()

    def forward(self, x: Tensor) -> Tensor:
        # stem -> stacked MaxVit blocks -> classification head
        x = self.stem(x)
        for block in self.blocks:
            x = block(x)
        x = self.classifier(x)
        return x

    def _init_weights(self):
        # Normal(std=0.02) init for conv/linear weights, unit affine for batchnorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
727
+
728
+
729
def _maxvit(
    # stem parameters
    stem_channels: int,
    # block parameters
    block_channels: List[int],
    block_layers: List[int],
    stochastic_depth_prob: float,
    # partitioning parameters
    partition_size: int,
    # transformer parameters
    head_dim: int,
    # Weights API
    weights: Optional[WeightsEnum] = None,
    progress: bool = False,
    # kwargs,
    **kwargs: Any,
) -> MaxVit:
    """Build a :class:`MaxVit` model and optionally load pretrained weights.

    When ``weights`` is given, ``num_classes`` and ``input_size`` are overridden
    from the weight metadata so the checkpoint's shapes match the model.

    Raises:
        ValueError: If the weight metadata's ``min_size`` is not square.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
        # Explicit check instead of `assert`: asserts are stripped when Python
        # runs with -O, which would silently skip this validation.
        if weights.meta["min_size"][0] != weights.meta["min_size"][1]:
            raise ValueError(f"Weight metadata min_size {weights.meta['min_size']} is expected to be square.")
        _ovewrite_named_param(kwargs, "input_size", weights.meta["min_size"])

    input_size = kwargs.pop("input_size", (224, 224))

    model = MaxVit(
        stem_channels=stem_channels,
        block_channels=block_channels,
        block_layers=block_layers,
        stochastic_depth_prob=stochastic_depth_prob,
        head_dim=head_dim,
        partition_size=partition_size,
        input_size=input_size,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
769
+
770
+
771
class MaxVit_T_Weights(WeightsEnum):
    # ImageNet-1K weights for the "tiny" MaxViT; see meta["recipe"] for the
    # training configuration.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/maxvit_t-bc5ab103.pth",
        transforms=partial(
            ImageClassification, crop_size=224, resize_size=224, interpolation=InterpolationMode.BICUBIC
        ),
        meta={
            "categories": _IMAGENET_CATEGORIES,
            "num_params": 30919624,
            "min_size": (224, 224),
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#maxvit",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 83.700,
                    "acc@5": 96.722,
                }
            },
            "_ops": 5.558,
            "_file_size": 118.769,
            "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.
            They were trained with a BatchNorm2D momentum of 0.99 instead of the more correct 0.01.""",
        },
    )
    DEFAULT = IMAGENET1K_V1
796
+
797
+
798
@register_model()
@handle_legacy_interface(weights=("pretrained", MaxVit_T_Weights.IMAGENET1K_V1))
def maxvit_t(*, weights: Optional[MaxVit_T_Weights] = None, progress: bool = True, **kwargs: Any) -> MaxVit:
    """
    Constructs a maxvit_t architecture from
    `MaxViT: Multi-Axis Vision Transformer <https://arxiv.org/abs/2204.01697>`_.

    Args:
        weights (:class:`~torchvision.models.MaxVit_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MaxVit_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.maxvit.MaxVit``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/maxvit.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MaxVit_T_Weights
        :members:
    """
    weights = MaxVit_T_Weights.verify(weights)

    # Architecture hyper-parameters for the "tiny" MaxViT variant.
    tiny_config = dict(
        stem_channels=64,
        block_channels=[64, 128, 256, 512],
        block_layers=[2, 2, 5, 2],
        head_dim=32,
        stochastic_depth_prob=0.2,
        partition_size=7,
    )
    return _maxvit(weights=weights, progress=progress, **tiny_config, **kwargs)
.venv/lib/python3.11/site-packages/torchvision/models/mobilenetv3.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Callable, List, Optional, Sequence
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+
7
+ from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer
8
+ from ..transforms._presets import ImageClassification
9
+ from ..utils import _log_api_usage_once
10
+ from ._api import register_model, Weights, WeightsEnum
11
+ from ._meta import _IMAGENET_CATEGORIES
12
+ from ._utils import _make_divisible, _ovewrite_named_param, handle_legacy_interface
13
+
14
+
15
+ __all__ = [
16
+ "MobileNetV3",
17
+ "MobileNet_V3_Large_Weights",
18
+ "MobileNet_V3_Small_Weights",
19
+ "mobilenet_v3_large",
20
+ "mobilenet_v3_small",
21
+ ]
22
+
23
+
24
class InvertedResidualConfig:
    """One bottleneck-row configuration from Tables 1 and 2 of the MobileNetV3 paper."""

    def __init__(
        self,
        input_channels: int,
        kernel: int,
        expanded_channels: int,
        out_channels: int,
        use_se: bool,
        activation: str,
        stride: int,
        dilation: int,
        width_mult: float,
    ):
        # Channel counts are scaled by the width multiplier and rounded to a
        # multiple of 8; the remaining fields are stored verbatim.
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.kernel = kernel
        self.use_se = use_se
        self.use_hs = activation == "HS"  # "HS" -> Hardswish, otherwise ReLU
        self.stride = stride
        self.dilation = dilation

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        """Scale ``channels`` by ``width_mult`` and round to a multiple of 8."""
        return _make_divisible(channels * width_mult, 8)
50
+
51
+
52
class InvertedResidual(nn.Module):
    """MobileNetV3 inverted residual block (paper section 5): optional 1x1
    expansion, depthwise conv, optional squeeze-excitation, and a linear 1x1
    projection, with a skip connection when the shape is preserved."""

    def __init__(
        self,
        cnf: InvertedResidualConfig,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid),
    ):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError("illegal stride value")

        # The residual connection is only valid when the block keeps both the
        # spatial resolution and the channel count.
        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.Hardswish if cnf.use_hs else nn.ReLU
        stages: List[nn.Module] = []

        # 1x1 expansion (skipped when the block does not expand)
        if cnf.expanded_channels != cnf.input_channels:
            stages.append(
                Conv2dNormActivation(
                    cnf.input_channels,
                    cnf.expanded_channels,
                    kernel_size=1,
                    norm_layer=norm_layer,
                    activation_layer=act,
                )
            )

        # depthwise conv; with dilation > 1 the stride is forced back to 1
        dw_stride = 1 if cnf.dilation > 1 else cnf.stride
        stages.append(
            Conv2dNormActivation(
                cnf.expanded_channels,
                cnf.expanded_channels,
                kernel_size=cnf.kernel,
                stride=dw_stride,
                dilation=cnf.dilation,
                groups=cnf.expanded_channels,
                norm_layer=norm_layer,
                activation_layer=act,
            )
        )

        # optional squeeze-and-excitation on the expanded representation
        if cnf.use_se:
            squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8)
            stages.append(se_layer(cnf.expanded_channels, squeeze_channels))

        # linear 1x1 projection (no activation)
        stages.append(
            Conv2dNormActivation(
                cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
            )
        )

        self.block = nn.Sequential(*stages)
        self.out_channels = cnf.out_channels
        self._is_cn = cnf.stride > 1

    def forward(self, input: Tensor) -> Tensor:
        out = self.block(input)
        if self.use_res_connect:
            out = out + input
        return out
115
+
116
+
117
class MobileNetV3(nn.Module):
    def __init__(
        self,
        inverted_residual_setting: List[InvertedResidualConfig],
        last_channel: int,
        num_classes: int = 1000,
        block: Optional[Callable[..., nn.Module]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        dropout: float = 0.2,
        **kwargs: Any,
    ) -> None:
        """
        MobileNet V3 main class

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
            last_channel (int): The number of channels on the penultimate layer
            num_classes (int): Number of classes
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
            dropout (float): The dropout probability
        """
        super().__init__()
        _log_api_usage_once(self)

        # Validate the network structure up front so misconfiguration fails fast.
        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])
        ):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                3,
                firstconv_output_channels,
                kernel_size=3,
                stride=2,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(
            Conv2dNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.Hardswish,
            )
        )

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(lastconv_output_channels, last_channel),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=dropout, inplace=True),
            nn.Linear(last_channel, num_classes),
        )

        # Standard init: kaiming for convs, unit affine for norms, small normal for linears.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # features -> global average pool -> flatten -> classifier head
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
221
+
222
+
223
def _mobilenet_v3_conf(
    arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any
):
    """Return the ``(inverted_residual_setting, last_channel)`` pair for ``arch``.

    Args:
        arch (str): "mobilenet_v3_large" or "mobilenet_v3_small".
        width_mult (float): Multiplier applied to every channel count.
        reduced_tail (bool): Halve the channel counts of the last stage (C4/C5).
        dilated (bool): Use dilation 2 instead of 1 in the last stage.
        **kwargs: Ignored here; accepted so shared kwargs can be passed through.
    """
    reduce_divider = 2 if reduced_tail else 1
    dilation = 2 if dilated else 1

    # bneck_conf args: input_c, kernel, expanded_c, out_c, use_se, activation, stride, dilation
    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    if arch == "mobilenet_v3_large":
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
            bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
            bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
            bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
            bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
            bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
            bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
            bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
        ]
        last_channel = adjust_channels(1280 // reduce_divider)  # C5
    elif arch == "mobilenet_v3_small":
        inverted_residual_setting = [
            bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
            bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
            bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
            bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
            bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
            bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
            bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
        ]
        last_channel = adjust_channels(1024 // reduce_divider)  # C5
    else:
        raise ValueError(f"Unsupported model type {arch}")

    return inverted_residual_setting, last_channel
270
+
271
+
272
def _mobilenet_v3(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> MobileNetV3:
    """Instantiate a MobileNetV3 and, if ``weights`` is given, load its state dict.

    When pretrained weights are supplied, ``num_classes`` is overridden to match
    the number of categories recorded in the weight metadata.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)

    if weights is not None:
        state_dict = weights.get_state_dict(progress=progress, check_hash=True)
        model.load_state_dict(state_dict)

    return model
288
+
289
+
290
# Metadata fields shared by every MobileNetV3 weight entry below.
_COMMON_META = {
    "min_size": (1, 1),
    "categories": _IMAGENET_CATEGORIES,
}
294
+
295
+
296
class MobileNet_V3_Large_Weights(WeightsEnum):
    # Original torchvision training recipe.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 5483032,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.042,
                    "acc@5": 91.340,
                }
            },
            "_ops": 0.217,
            "_file_size": 21.114,
            "_docs": """These weights were trained from scratch by using a simple training recipe.""",
        },
    )
    # Improved recipe; note the larger resize_size used at evaluation time.
    IMAGENET1K_V2 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
        transforms=partial(ImageClassification, crop_size=224, resize_size=232),
        meta={
            **_COMMON_META,
            "num_params": 5483032,
            "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 75.274,
                    "acc@5": 92.566,
                }
            },
            "_ops": 0.217,
            "_file_size": 21.107,
            "_docs": """
                These weights improve marginally upon the results of the original paper by using a modified version of
                TorchVision's `new training recipe
                <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V2
338
+
339
+
340
class MobileNet_V3_Small_Weights(WeightsEnum):
    # Only one published checkpoint exists for the small variant.
    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 2542856,
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 67.668,
                    "acc@5": 87.402,
                }
            },
            "_ops": 0.057,
            "_file_size": 9.829,
            "_docs": """
                These weights improve upon the results of the original paper by using a simple training recipe.
            """,
        },
    )
    DEFAULT = IMAGENET1K_V1
362
+
363
+
364
@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Large_Weights.IMAGENET1K_V1))
def mobilenet_v3_large(
    *, weights: Optional[MobileNet_V3_Large_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
    """
    Constructs a large MobileNetV3 architecture from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V3_Large_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V3_Large_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V3_Large_Weights
        :members:
    """
    weights = MobileNet_V3_Large_Weights.verify(weights)

    arch = "mobilenet_v3_large"
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
    return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
393
+
394
+
395
@register_model()
@handle_legacy_interface(weights=("pretrained", MobileNet_V3_Small_Weights.IMAGENET1K_V1))
def mobilenet_v3_small(
    *, weights: Optional[MobileNet_V3_Small_Weights] = None, progress: bool = True, **kwargs: Any
) -> MobileNetV3:
    """
    Constructs a small MobileNetV3 architecture from
    `Searching for MobileNetV3 <https://arxiv.org/abs/1905.02244>`__.

    Args:
        weights (:class:`~torchvision.models.MobileNet_V3_Small_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.MobileNet_V3_Small_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.mobilenet.MobileNetV3``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv3.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.MobileNet_V3_Small_Weights
        :members:
    """
    weights = MobileNet_V3_Small_Weights.verify(weights)

    arch = "mobilenet_v3_small"
    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs)
    return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)
.venv/lib/python3.11/site-packages/torchvision/models/squeezenet.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.init as init
7
+
8
+ from ..transforms._presets import ImageClassification
9
+ from ..utils import _log_api_usage_once
10
+ from ._api import register_model, Weights, WeightsEnum
11
+ from ._meta import _IMAGENET_CATEGORIES
12
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
13
+
14
+
15
+ __all__ = ["SqueezeNet", "SqueezeNet1_0_Weights", "SqueezeNet1_1_Weights", "squeezenet1_0", "squeezenet1_1"]
16
+
17
+
18
class Fire(nn.Module):
    """SqueezeNet "fire" module.

    A 1x1 "squeeze" convolution reduces the channel count, then two parallel
    "expand" convolutions (1x1 and 3x3) are applied and concatenated
    channel-wise, so the output has ``expand1x1_planes + expand3x3_planes``
    channels.
    """

    def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None:
        super().__init__()
        self.inplanes = inplanes
        # Squeeze path: cheap 1x1 conv that shrinks channels before the expand convs.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        # Expand paths: two branches whose outputs are concatenated in forward().
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squeezed = self.squeeze_activation(self.squeeze(x))
        branch_1x1 = self.expand1x1_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand3x3_activation(self.expand3x3(squeezed))
        return torch.cat([branch_1x1, branch_3x3], 1)
34
+
35
+
36
class SqueezeNet(nn.Module):
    """SqueezeNet image classifier (versions ``"1_0"`` and ``"1_1"``).

    Args:
        version: Architecture variant, ``"1_0"`` or ``"1_1"``.
        num_classes: Size of the classification output.
        dropout: Dropout probability applied before the final 1x1 conv classifier.

    Raises:
        ValueError: If ``version`` is not one of the two supported strings.
    """

    def __init__(self, version: str = "1_0", num_classes: int = 1000, dropout: float = 0.5) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes
        if version == "1_0":
            layers = [
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            ]
        elif version == "1_1":
            layers = [
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            ]
        else:
            # Defensive check; the public squeezenet1_x() factories always pass a
            # valid version string.
            raise ValueError(f"Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected")
        self.features = nn.Sequential(*layers)

        # Classification head: the final 1x1 conv maps to num_classes and is
        # initialized differently from every other conv (see the loop below).
        last_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            last_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            if module is last_conv:
                init.normal_(module.weight, mean=0.0, std=0.01)
            else:
                init.kaiming_uniform_(module.weight)
            if module.bias is not None:
                init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv trunk and conv classifier; returns ``(N, num_classes)`` logits."""
        out = self.classifier(self.features(x))
        return torch.flatten(out, 1)
98
+
99
+
100
def _squeezenet(
    version: str,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SqueezeNet:
    """Shared SqueezeNet factory: build the requested version and optionally
    load a pretrained checkpoint.
    """
    pretrained = weights is not None
    if pretrained:
        # Pretrained checkpoints fix the class count to the training dataset's.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    net = SqueezeNet(version, **kwargs)

    if pretrained:
        state_dict = weights.get_state_dict(progress=progress, check_hash=True)
        net.load_state_dict(state_dict)

    return net
115
+
116
+
117
# Metadata shared by every SqueezeNet weight entry; individual entries spread
# this dict into their own ``meta`` and add per-checkpoint fields.
_COMMON_META = {
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/pull/49#issuecomment-277560717",
    "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}
122
+
123
+
124
class SqueezeNet1_0_Weights(WeightsEnum):
    """Pretrained weight registry for SqueezeNet 1.0: checkpoint URL,
    eval-time preprocessing transform, and reported metadata."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "min_size": (21, 21),
            "num_params": 1248424,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 58.092,
                    "acc@5": 80.420,
                }
            },
            "_ops": 0.819,
            "_file_size": 4.778,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
143
+
144
+
145
class SqueezeNet1_1_Weights(WeightsEnum):
    """Pretrained weight registry for SqueezeNet 1.1: checkpoint URL,
    eval-time preprocessing transform, and reported metadata."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "min_size": (17, 17),
            "num_params": 1235496,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 58.178,
                    "acc@5": 80.624,
                }
            },
            "_ops": 0.349,
            "_file_size": 4.729,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
164
+
165
+
166
@register_model()
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_0_Weights.IMAGENET1K_V1))
def squeezenet1_0(
    *, weights: Optional[SqueezeNet1_0_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
    """Build SqueezeNet 1.0 from `SqueezeNet: AlexNet-level accuracy with 50x
    fewer parameters and <0.5MB model size <https://arxiv.org/abs/1602.07360>`_.

    Args:
        weights (:class:`~torchvision.models.SqueezeNet1_0_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.squeezenet.SqueezeNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
            for details.

    .. autoclass:: torchvision.models.SqueezeNet1_0_Weights
        :members:
    """
    checked_weights = SqueezeNet1_0_Weights.verify(weights)
    return _squeezenet("1_0", checked_weights, progress, **kwargs)
193
+
194
+
195
@register_model()
@handle_legacy_interface(weights=("pretrained", SqueezeNet1_1_Weights.IMAGENET1K_V1))
def squeezenet1_1(
    *, weights: Optional[SqueezeNet1_1_Weights] = None, progress: bool = True, **kwargs: Any
) -> SqueezeNet:
    """Build SqueezeNet 1.1 from the `official SqueezeNet repo
    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.

    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
    than SqueezeNet 1.0, without sacrificing accuracy.

    Args:
        weights (:class:`~torchvision.models.SqueezeNet1_1_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.squeezenet.SqueezeNet``
            base class. See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/squeezenet.py>`_
            for details.

    .. autoclass:: torchvision.models.SqueezeNet1_1_Weights
        :members:
    """
    checked_weights = SqueezeNet1_1_Weights.verify(weights)
    return _squeezenet("1_1", checked_weights, progress, **kwargs)
.venv/lib/python3.11/site-packages/torchvision/models/vgg.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, cast, Dict, List, Optional, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..transforms._presets import ImageClassification
8
+ from ..utils import _log_api_usage_once
9
+ from ._api import register_model, Weights, WeightsEnum
10
+ from ._meta import _IMAGENET_CATEGORIES
11
+ from ._utils import _ovewrite_named_param, handle_legacy_interface
12
+
13
+
14
+ __all__ = [
15
+ "VGG",
16
+ "VGG11_Weights",
17
+ "VGG11_BN_Weights",
18
+ "VGG13_Weights",
19
+ "VGG13_BN_Weights",
20
+ "VGG16_Weights",
21
+ "VGG16_BN_Weights",
22
+ "VGG19_Weights",
23
+ "VGG19_BN_Weights",
24
+ "vgg11",
25
+ "vgg11_bn",
26
+ "vgg13",
27
+ "vgg13_bn",
28
+ "vgg16",
29
+ "vgg16_bn",
30
+ "vgg19",
31
+ "vgg19_bn",
32
+ ]
33
+
34
+
35
class VGG(nn.Module):
    """Generic VGG classifier: a convolutional trunk, 7x7 adaptive average
    pooling, and a three-layer fully connected head.

    Args:
        features: Convolutional trunk (typically built by ``make_layers``).
        num_classes: Size of the classification output.
        init_weights: If ``True``, apply the reference weight initialization.
        dropout: Dropout probability in the fully connected head.
    """

    def __init__(
        self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.features = features
        # Fixes the spatial size fed to the head at 7x7 regardless of input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self) -> None:
        # Reference init: He-normal convs, unit-scale batchnorms, small-normal linears.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run trunk, pool, flatten, and head; returns ``(N, num_classes)`` logits."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))
71
+
72
+
73
def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
    """Translate a VGG config list into a ``nn.Sequential`` trunk.

    Each integer entry adds a 3x3 convolution (optionally followed by
    BatchNorm) and a ReLU; each ``"M"`` entry adds a 2x2 max pool. Channel
    counts thread from one conv to the next, starting from 3 (RGB input).
    """
    modules: List[nn.Module] = []
    channels = 3
    for entry in cfg:
        if entry == "M":
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        width = cast(int, entry)
        modules.append(nn.Conv2d(channels, width, kernel_size=3, padding=1))
        if batch_norm:
            modules.append(nn.BatchNorm2d(width))
        modules.append(nn.ReLU(inplace=True))
        channels = width
    return nn.Sequential(*modules)
88
+
89
+
90
# VGG trunk configurations keyed by the letter naming used in the paper:
# ints are 3x3-conv output widths, "M" marks a 2x2 max pool.
# Conv counts (8/10/13/16) plus the 3 FC layers give VGG-11/13/16/19:
# A -> vgg11, B -> vgg13, D -> vgg16, E -> vgg19.
cfgs: Dict[str, List[Union[str, int]]] = {
    "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
    "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
96
+
97
+
98
def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG:
    """Shared VGG factory: build the requested configuration and optionally
    load a pretrained checkpoint.
    """
    pretrained = weights is not None
    if pretrained:
        # Pretrained checkpoints already carry trained parameters, so skip init.
        kwargs["init_weights"] = False
        categories = weights.meta["categories"]
        if categories is not None:
            # Feature-only checkpoints (categories=None) leave num_classes alone.
            _ovewrite_named_param(kwargs, "num_classes", len(categories))
    net = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
    if pretrained:
        net.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return net
107
+
108
+
109
# Metadata shared by every VGG weight entry; individual entries spread this
# dict into their own ``meta`` and may override fields (e.g. ``categories``).
_COMMON_META = {
    "min_size": (32, 32),
    "categories": _IMAGENET_CATEGORIES,
    "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg",
    "_docs": """These weights were trained from scratch by using a simplified training recipe.""",
}
115
+
116
+
117
class VGG11_Weights(WeightsEnum):
    """Pretrained weight registry for VGG-11 (config "A", no batchnorm)."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg11-8a719046.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 132863336,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.020,
                    "acc@5": 88.628,
                }
            },
            "_ops": 7.609,
            "_file_size": 506.84,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
135
+
136
+
137
class VGG11_BN_Weights(WeightsEnum):
    """Pretrained weight registry for VGG-11 with batchnorm (config "A")."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 132868840,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 70.370,
                    "acc@5": 89.810,
                }
            },
            "_ops": 7.609,
            "_file_size": 506.881,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
155
+
156
+
157
class VGG13_Weights(WeightsEnum):
    """Pretrained weight registry for VGG-13 (config "B", no batchnorm)."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg13-19584684.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 133047848,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 69.928,
                    "acc@5": 89.246,
                }
            },
            "_ops": 11.308,
            "_file_size": 507.545,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
175
+
176
+
177
class VGG13_BN_Weights(WeightsEnum):
    """Pretrained weight registry for VGG-13 with batchnorm (config "B")."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 133053736,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.586,
                    "acc@5": 90.374,
                }
            },
            "_ops": 11.308,
            "_file_size": 507.59,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
195
+
196
+
197
class VGG16_Weights(WeightsEnum):
    """Pretrained weight registry for VGG-16 (config "D", no batchnorm).

    Besides the standard classification checkpoint, this enum also exposes a
    feature-extraction-only checkpoint (``IMAGENET1K_FEATURES``) whose
    classifier weights are not usable — see its ``_docs`` entry.
    """

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16-397923af.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 71.592,
                    "acc@5": 90.382,
                }
            },
            "_ops": 15.47,
            "_file_size": 527.796,
        },
    )
    IMAGENET1K_FEATURES = Weights(
        # Weights ported from https://github.com/amdegroot/ssd.pytorch/
        url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
        transforms=partial(
            ImageClassification,
            crop_size=224,
            # Original-paper input standardization (mean subtraction only, 0-255 scale).
            mean=(0.48235, 0.45882, 0.40784),
            std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
        ),
        meta={
            **_COMMON_META,
            "num_params": 138357544,
            # categories=None marks this entry as not usable for classification.
            "categories": None,
            "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd",
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": float("nan"),
                    "acc@5": float("nan"),
                }
            },
            "_ops": 15.47,
            "_file_size": 527.802,
            "_docs": """
                These weights can't be used for classification because they are missing values in the `classifier`
                module. Only the `features` module has valid values and can be used for feature extraction. The weights
                were trained using the original input standardization method as described in the paper.
            """,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
244
+
245
+
246
class VGG16_BN_Weights(WeightsEnum):
    """Pretrained weight registry for VGG-16 with batchnorm (config "D")."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 138365992,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 73.360,
                    "acc@5": 91.516,
                }
            },
            "_ops": 15.47,
            "_file_size": 527.866,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
264
+
265
+
266
class VGG19_Weights(WeightsEnum):
    """Pretrained weight registry for VGG-19 (config "E", no batchnorm)."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143667240,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 72.376,
                    "acc@5": 90.876,
                }
            },
            "_ops": 19.632,
            "_file_size": 548.051,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
284
+
285
+
286
class VGG19_BN_Weights(WeightsEnum):
    """Pretrained weight registry for VGG-19 with batchnorm (config "E")."""

    IMAGENET1K_V1 = Weights(
        url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
        transforms=partial(ImageClassification, crop_size=224),
        meta={
            **_COMMON_META,
            "num_params": 143678248,
            "_metrics": {
                "ImageNet-1K": {
                    "acc@1": 74.218,
                    "acc@5": 91.842,
                }
            },
            "_ops": 19.632,
            "_file_size": 548.143,
        },
    )
    # Alias used when callers ask for the default checkpoint.
    DEFAULT = IMAGENET1K_V1
304
+
305
+
306
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1))
def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """Build VGG-11 (configuration "A") from `Very Deep Convolutional Networks
    for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.vgg.VGG`` base class.
            See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for details.

    .. autoclass:: torchvision.models.VGG11_Weights
        :members:
    """
    checked_weights = VGG11_Weights.verify(weights)
    return _vgg("A", False, checked_weights, progress, **kwargs)
330
+
331
+
332
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1))
def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """Build VGG-11-BN (configuration "A" with batchnorm) from `Very Deep
    Convolutional Networks for Large-Scale Image Recognition
    <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.vgg.VGG`` base class.
            See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for details.

    .. autoclass:: torchvision.models.VGG11_BN_Weights
        :members:
    """
    checked_weights = VGG11_BN_Weights.verify(weights)
    return _vgg("A", True, checked_weights, progress, **kwargs)
356
+
357
+
358
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1))
def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """Build VGG-13 (configuration "B") from `Very Deep Convolutional Networks
    for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.vgg.VGG`` base class.
            See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for details.

    .. autoclass:: torchvision.models.VGG13_Weights
        :members:
    """
    checked_weights = VGG13_Weights.verify(weights)
    return _vgg("B", False, checked_weights, progress, **kwargs)
382
+
383
+
384
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1))
def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """Build VGG-13-BN (configuration "B" with batchnorm) from `Very Deep
    Convolutional Networks for Large-Scale Image Recognition
    <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.vgg.VGG`` base class.
            See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for details.

    .. autoclass:: torchvision.models.VGG13_BN_Weights
        :members:
    """
    checked_weights = VGG13_BN_Weights.verify(weights)
    return _vgg("B", True, checked_weights, progress, **kwargs)
408
+
409
+
410
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1))
def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """Build VGG-16 (configuration "D") from `Very Deep Convolutional Networks
    for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.vgg.VGG`` base class.
            See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for details.

    .. autoclass:: torchvision.models.VGG16_Weights
        :members:
    """
    checked_weights = VGG16_Weights.verify(weights)
    return _vgg("D", False, checked_weights, progress, **kwargs)
434
+
435
+
436
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1))
def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """Build VGG-16-BN (configuration "D" with batchnorm) from `Very Deep
    Convolutional Networks for Large-Scale Image Recognition
    <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.vgg.VGG`` base class.
            See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for details.

    .. autoclass:: torchvision.models.VGG16_BN_Weights
        :members:
    """
    checked_weights = VGG16_BN_Weights.verify(weights)
    return _vgg("D", True, checked_weights, progress, **kwargs)
460
+
461
+
462
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1))
def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """Build VGG-19 (configuration "E") from `Very Deep Convolutional Networks
    for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.vgg.VGG`` base class.
            See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for details.

    .. autoclass:: torchvision.models.VGG19_Weights
        :members:
    """
    checked_weights = VGG19_Weights.verify(weights)
    return _vgg("E", False, checked_weights, progress, **kwargs)
486
+
487
+
488
@register_model()
@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1))
def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
    """Build VGG-19-BN (configuration "E" with batchnorm) from `Very Deep
    Convolutional Networks for Large-Scale Image Recognition
    <https://arxiv.org/abs/1409.1556>`__.

    Args:
        weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional):
            Pretrained weights to load. ``None`` (the default) builds a randomly
            initialized model; see the weight enum below for available checkpoints.
        progress (bool, optional): Show a download progress bar on stderr.
            Default is ``True``.
        **kwargs: Forwarded to the ``torchvision.models.vgg.VGG`` base class.
            See the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py>`_
            for details.

    .. autoclass:: torchvision.models.VGG19_BN_Weights
        :members:
    """
    checked_weights = VGG19_BN_Weights.verify(weights)
    return _vgg("E", True, checked_weights, progress, **kwargs)
.venv/lib/python3.11/site-packages/torchvision/ops/__init__.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._register_onnx_ops import _register_custom_op
2
+ from .boxes import (
3
+ batched_nms,
4
+ box_area,
5
+ box_convert,
6
+ box_iou,
7
+ clip_boxes_to_image,
8
+ complete_box_iou,
9
+ distance_box_iou,
10
+ generalized_box_iou,
11
+ masks_to_boxes,
12
+ nms,
13
+ remove_small_boxes,
14
+ )
15
+ from .ciou_loss import complete_box_iou_loss
16
+ from .deform_conv import deform_conv2d, DeformConv2d
17
+ from .diou_loss import distance_box_iou_loss
18
+ from .drop_block import drop_block2d, drop_block3d, DropBlock2d, DropBlock3d
19
+ from .feature_pyramid_network import FeaturePyramidNetwork
20
+ from .focal_loss import sigmoid_focal_loss
21
+ from .giou_loss import generalized_box_iou_loss
22
+ from .misc import Conv2dNormActivation, Conv3dNormActivation, FrozenBatchNorm2d, MLP, Permute, SqueezeExcitation
23
+ from .poolers import MultiScaleRoIAlign
24
+ from .ps_roi_align import ps_roi_align, PSRoIAlign
25
+ from .ps_roi_pool import ps_roi_pool, PSRoIPool
26
+ from .roi_align import roi_align, RoIAlign
27
+ from .roi_pool import roi_pool, RoIPool
28
+ from .stochastic_depth import stochastic_depth, StochasticDepth
29
+
30
+ _register_custom_op()
31
+
32
+
33
+ __all__ = [
34
+ "masks_to_boxes",
35
+ "deform_conv2d",
36
+ "DeformConv2d",
37
+ "nms",
38
+ "batched_nms",
39
+ "remove_small_boxes",
40
+ "clip_boxes_to_image",
41
+ "box_convert",
42
+ "box_area",
43
+ "box_iou",
44
+ "generalized_box_iou",
45
+ "distance_box_iou",
46
+ "complete_box_iou",
47
+ "roi_align",
48
+ "RoIAlign",
49
+ "roi_pool",
50
+ "RoIPool",
51
+ "ps_roi_align",
52
+ "PSRoIAlign",
53
+ "ps_roi_pool",
54
+ "PSRoIPool",
55
+ "MultiScaleRoIAlign",
56
+ "FeaturePyramidNetwork",
57
+ "sigmoid_focal_loss",
58
+ "stochastic_depth",
59
+ "StochasticDepth",
60
+ "FrozenBatchNorm2d",
61
+ "Conv2dNormActivation",
62
+ "Conv3dNormActivation",
63
+ "SqueezeExcitation",
64
+ "MLP",
65
+ "Permute",
66
+ "generalized_box_iou_loss",
67
+ "distance_box_iou_loss",
68
+ "complete_box_iou_loss",
69
+ "drop_block2d",
70
+ "DropBlock2d",
71
+ "drop_block3d",
72
+ "DropBlock3d",
73
+ ]
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.22 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_box_convert.cpython-311.pyc ADDED
Binary file (3.46 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_register_onnx_ops.cpython-311.pyc ADDED
Binary file (6.19 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/_utils.cpython-311.pyc ADDED
Binary file (7.17 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/boxes.cpython-311.pyc ADDED
Binary file (23.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/ciou_loss.cpython-311.pyc ADDED
Binary file (3.94 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/deform_conv.cpython-311.pyc ADDED
Binary file (9.67 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/diou_loss.cpython-311.pyc ADDED
Binary file (4.3 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/drop_block.cpython-311.pyc ADDED
Binary file (9.13 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/feature_pyramid_network.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/focal_loss.cpython-311.pyc ADDED
Binary file (2.91 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/giou_loss.cpython-311.pyc ADDED
Binary file (3.65 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/misc.cpython-311.pyc ADDED
Binary file (20 kB). View file
 
.venv/lib/python3.11/site-packages/torchvision/ops/__pycache__/poolers.cpython-311.pyc ADDED
Binary file (16.5 kB). View file