Add files using upload-large-folder tool
- pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/__init__.py +0 -0
- pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/__pycache__/test_algebraic_connectivity.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_graphmatrix.py +276 -0
- pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_laplacian.py +336 -0
- pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_modularity.py +87 -0
- pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_spectrum.py +71 -0
- pythonProject/.venv/Lib/site-packages/numpy/ctypeslib.py +602 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h +103 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h +61 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h +84 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h +38 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h +475 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h +81 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h +169 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h +126 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h +124 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h +22 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h +209 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h +187 -0
- pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/Macros.h +3 -0
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/__init__.py
ADDED
File without changes
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/__pycache__/test_algebraic_connectivity.cpython-310.pyc
ADDED
Binary file (12.2 kB)
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_graphmatrix.py
ADDED
@@ -0,0 +1,276 @@
import pytest

np = pytest.importorskip("numpy")
pytest.importorskip("scipy")

import networkx as nx
from networkx.exception import NetworkXError
from networkx.generators.degree_seq import havel_hakimi_graph


def test_incidence_matrix_simple():
    deg = [3, 2, 2, 1, 0]
    G = havel_hakimi_graph(deg)
    deg = [(1, 0), (1, 0), (1, 0), (2, 0), (1, 0), (2, 1), (0, 1), (0, 1)]
    MG = nx.random_clustered_graph(deg, seed=42)

    I = nx.incidence_matrix(G, dtype=int).todense()
    # fmt: off
    expected = np.array(
        [[1, 1, 1, 0],
         [0, 1, 0, 1],
         [1, 0, 0, 1],
         [0, 0, 1, 0],
         [0, 0, 0, 0]]
    )
    # fmt: on
    np.testing.assert_equal(I, expected)

    I = nx.incidence_matrix(MG, dtype=int).todense()
    # fmt: off
    expected = np.array(
        [[1, 0, 0, 0, 0, 0, 0],
         [1, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1, 0, 1]]
    )
    # fmt: on
    np.testing.assert_equal(I, expected)

    with pytest.raises(NetworkXError):
        nx.incidence_matrix(G, nodelist=[0, 1])


class TestGraphMatrix:
    @classmethod
    def setup_class(cls):
        deg = [3, 2, 2, 1, 0]
        cls.G = havel_hakimi_graph(deg)
        # fmt: off
        cls.OI = np.array(
            [[-1, -1, -1,  0],
             [ 1,  0,  0, -1],
             [ 0,  1,  0,  1],
             [ 0,  0,  1,  0],
             [ 0,  0,  0,  0]]
        )
        cls.A = np.array(
            [[0, 1, 1, 1, 0],
             [1, 0, 1, 0, 0],
             [1, 1, 0, 0, 0],
             [1, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]]
        )
        # fmt: on
        cls.WG = havel_hakimi_graph(deg)
        cls.WG.add_edges_from(
            (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges()
        )
        # fmt: off
        cls.WA = np.array(
            [[0,   0.5, 0.5, 0.5, 0],
             [0.5, 0,   0.5, 0,   0],
             [0.5, 0.5, 0,   0,   0],
             [0.5, 0,   0,   0,   0],
             [0,   0,   0,   0,   0]]
        )
        # fmt: on
        cls.MG = nx.MultiGraph(cls.G)
        cls.MG2 = cls.MG.copy()
        cls.MG2.add_edge(0, 1)
        # fmt: off
        cls.MG2A = np.array(
            [[0, 2, 1, 1, 0],
             [2, 0, 1, 0, 0],
             [1, 1, 0, 0, 0],
             [1, 0, 0, 0, 0],
             [0, 0, 0, 0, 0]]
        )
        cls.MGOI = np.array(
            [[-1, -1, -1, -1,  0],
             [ 1,  1,  0,  0, -1],
             [ 0,  0,  1,  0,  1],
             [ 0,  0,  0,  1,  0],
             [ 0,  0,  0,  0,  0]]
        )
        # fmt: on
        cls.no_edges_G = nx.Graph([(1, 2), (3, 2, {"weight": 8})])
        cls.no_edges_A = np.array([[0, 0], [0, 0]])

    def test_incidence_matrix(self):
        "Conversion to incidence matrix"
        I = nx.incidence_matrix(
            self.G,
            nodelist=sorted(self.G),
            edgelist=sorted(self.G.edges()),
            oriented=True,
            dtype=int,
        ).todense()
        np.testing.assert_equal(I, self.OI)

        I = nx.incidence_matrix(
            self.G,
            nodelist=sorted(self.G),
            edgelist=sorted(self.G.edges()),
            oriented=False,
            dtype=int,
        ).todense()
        np.testing.assert_equal(I, np.abs(self.OI))

        I = nx.incidence_matrix(
            self.MG,
            nodelist=sorted(self.MG),
            edgelist=sorted(self.MG.edges()),
            oriented=True,
            dtype=int,
        ).todense()
        np.testing.assert_equal(I, self.OI)

        I = nx.incidence_matrix(
            self.MG,
            nodelist=sorted(self.MG),
            edgelist=sorted(self.MG.edges()),
            oriented=False,
            dtype=int,
        ).todense()
        np.testing.assert_equal(I, np.abs(self.OI))

        I = nx.incidence_matrix(
            self.MG2,
            nodelist=sorted(self.MG2),
            edgelist=sorted(self.MG2.edges()),
            oriented=True,
            dtype=int,
        ).todense()
        np.testing.assert_equal(I, self.MGOI)

        I = nx.incidence_matrix(
            self.MG2,
            nodelist=sorted(self.MG),
            edgelist=sorted(self.MG2.edges()),
            oriented=False,
            dtype=int,
        ).todense()
        np.testing.assert_equal(I, np.abs(self.MGOI))

        I = nx.incidence_matrix(self.G, dtype=np.uint8)
        assert I.dtype == np.uint8

    def test_weighted_incidence_matrix(self):
        I = nx.incidence_matrix(
            self.WG,
            nodelist=sorted(self.WG),
            edgelist=sorted(self.WG.edges()),
            oriented=True,
            dtype=int,
        ).todense()
        np.testing.assert_equal(I, self.OI)

        I = nx.incidence_matrix(
            self.WG,
            nodelist=sorted(self.WG),
            edgelist=sorted(self.WG.edges()),
            oriented=False,
            dtype=int,
        ).todense()
        np.testing.assert_equal(I, np.abs(self.OI))

        # np.testing.assert_equal(nx.incidence_matrix(self.WG,oriented=True,
        #                                             weight='weight').todense(),0.5*self.OI)
        # np.testing.assert_equal(nx.incidence_matrix(self.WG,weight='weight').todense(),
        #                         np.abs(0.5*self.OI))
        # np.testing.assert_equal(nx.incidence_matrix(self.WG,oriented=True,weight='other').todense(),
        #                         0.3*self.OI)

        I = nx.incidence_matrix(
            self.WG,
            nodelist=sorted(self.WG),
            edgelist=sorted(self.WG.edges()),
            oriented=True,
            weight="weight",
        ).todense()
        np.testing.assert_equal(I, 0.5 * self.OI)

        I = nx.incidence_matrix(
            self.WG,
            nodelist=sorted(self.WG),
            edgelist=sorted(self.WG.edges()),
            oriented=False,
            weight="weight",
        ).todense()
        np.testing.assert_equal(I, np.abs(0.5 * self.OI))

        I = nx.incidence_matrix(
            self.WG,
            nodelist=sorted(self.WG),
            edgelist=sorted(self.WG.edges()),
            oriented=True,
            weight="other",
        ).todense()
        np.testing.assert_equal(I, 0.3 * self.OI)

        # WMG=nx.MultiGraph(self.WG)
        # WMG.add_edge(0,1,weight=0.5,other=0.3)
        # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='weight').todense(),
        #                         np.abs(0.5*self.MGOI))
        # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='weight',oriented=True).todense(),
        #                         0.5*self.MGOI)
        # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='other',oriented=True).todense(),
        #                         0.3*self.MGOI)

        WMG = nx.MultiGraph(self.WG)
        WMG.add_edge(0, 1, weight=0.5, other=0.3)

        I = nx.incidence_matrix(
            WMG,
            nodelist=sorted(WMG),
            edgelist=sorted(WMG.edges(keys=True)),
            oriented=True,
            weight="weight",
        ).todense()
        np.testing.assert_equal(I, 0.5 * self.MGOI)

        I = nx.incidence_matrix(
            WMG,
            nodelist=sorted(WMG),
            edgelist=sorted(WMG.edges(keys=True)),
            oriented=False,
            weight="weight",
        ).todense()
        np.testing.assert_equal(I, np.abs(0.5 * self.MGOI))

        I = nx.incidence_matrix(
            WMG,
            nodelist=sorted(WMG),
            edgelist=sorted(WMG.edges(keys=True)),
            oriented=True,
            weight="other",
        ).todense()
        np.testing.assert_equal(I, 0.3 * self.MGOI)

    def test_adjacency_matrix(self):
        "Conversion to adjacency matrix"
        np.testing.assert_equal(nx.adjacency_matrix(self.G).todense(), self.A)
        np.testing.assert_equal(nx.adjacency_matrix(self.MG).todense(), self.A)
        np.testing.assert_equal(nx.adjacency_matrix(self.MG2).todense(), self.MG2A)
        np.testing.assert_equal(
            nx.adjacency_matrix(self.G, nodelist=[0, 1]).todense(), self.A[:2, :2]
        )
        np.testing.assert_equal(nx.adjacency_matrix(self.WG).todense(), self.WA)
        np.testing.assert_equal(
            nx.adjacency_matrix(self.WG, weight=None).todense(), self.A
        )
        np.testing.assert_equal(
            nx.adjacency_matrix(self.MG2, weight=None).todense(), self.MG2A
        )
        np.testing.assert_equal(
            nx.adjacency_matrix(self.WG, weight="other").todense(), 0.6 * self.WA
        )
        np.testing.assert_equal(
            nx.adjacency_matrix(self.no_edges_G, nodelist=[1, 3]).todense(),
            self.no_edges_A,
        )
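For orientation, a minimal sketch (not part of this commit) of the incidence- and adjacency-matrix API these tests exercise; the triangle graph is purely illustrative:

import networkx as nx

G = nx.Graph([(0, 1), (0, 2), (1, 2)])  # small illustrative triangle
# oriented=True yields the signed node-edge incidence matrix (-1 at the source, +1 at the target)
I = nx.incidence_matrix(G, oriented=True, dtype=int).todense()
A = nx.adjacency_matrix(G, weight=None).todense()  # unweighted 0/1 adjacency
print(I.shape, A.shape)  # (3, 3) (3, 3): 3 nodes x 3 edges, then 3 x 3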
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_laplacian.py
ADDED
@@ -0,0 +1,336 @@
import pytest

np = pytest.importorskip("numpy")
pytest.importorskip("scipy")

import networkx as nx
from networkx.generators.degree_seq import havel_hakimi_graph
from networkx.generators.expanders import margulis_gabber_galil_graph


class TestLaplacian:
    @classmethod
    def setup_class(cls):
        deg = [3, 2, 2, 1, 0]
        cls.G = havel_hakimi_graph(deg)
        cls.WG = nx.Graph(
            (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges()
        )
        cls.WG.add_node(4)
        cls.MG = nx.MultiGraph(cls.G)

        # Graph with selfloops
        cls.Gsl = cls.G.copy()
        for node in cls.Gsl.nodes():
            cls.Gsl.add_edge(node, node)

        # Graph used as an example in Sec. 4.1 of Langville and Meyer,
        # "Google's PageRank and Beyond".
        cls.DiG = nx.DiGraph()
        cls.DiG.add_edges_from(
            (
                (1, 2),
                (1, 3),
                (3, 1),
                (3, 2),
                (3, 5),
                (4, 5),
                (4, 6),
                (5, 4),
                (5, 6),
                (6, 4),
            )
        )
        cls.DiMG = nx.MultiDiGraph(cls.DiG)
        cls.DiWG = nx.DiGraph(
            (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.DiG.edges()
        )
        cls.DiGsl = cls.DiG.copy()
        for node in cls.DiGsl.nodes():
            cls.DiGsl.add_edge(node, node)

    def test_laplacian(self):
        "Graph Laplacian"
        # fmt: off
        NL = np.array([[ 3, -1, -1, -1, 0],
                       [-1,  2, -1,  0, 0],
                       [-1, -1,  2,  0, 0],
                       [-1,  0,  0,  1, 0],
                       [ 0,  0,  0,  0, 0]])
        # fmt: on
        WL = 0.5 * NL
        OL = 0.3 * NL
        # fmt: off
        DiNL = np.array([[ 2, -1, -1,  0,  0,  0],
                         [ 0,  0,  0,  0,  0,  0],
                         [-1, -1,  3, -1,  0,  0],
                         [ 0,  0,  0,  2, -1, -1],
                         [ 0,  0,  0, -1,  2, -1],
                         [ 0,  0,  0,  0, -1,  1]])
        # fmt: on
        DiWL = 0.5 * DiNL
        DiOL = 0.3 * DiNL
        np.testing.assert_equal(nx.laplacian_matrix(self.G).todense(), NL)
        np.testing.assert_equal(nx.laplacian_matrix(self.MG).todense(), NL)
        np.testing.assert_equal(
            nx.laplacian_matrix(self.G, nodelist=[0, 1]).todense(),
            np.array([[1, -1], [-1, 1]]),
        )
        np.testing.assert_equal(nx.laplacian_matrix(self.WG).todense(), WL)
        np.testing.assert_equal(nx.laplacian_matrix(self.WG, weight=None).todense(), NL)
        np.testing.assert_equal(
            nx.laplacian_matrix(self.WG, weight="other").todense(), OL
        )

        np.testing.assert_equal(nx.laplacian_matrix(self.DiG).todense(), DiNL)
        np.testing.assert_equal(nx.laplacian_matrix(self.DiMG).todense(), DiNL)
        np.testing.assert_equal(
            nx.laplacian_matrix(self.DiG, nodelist=[1, 2]).todense(),
            np.array([[1, -1], [0, 0]]),
        )
        np.testing.assert_equal(nx.laplacian_matrix(self.DiWG).todense(), DiWL)
        np.testing.assert_equal(
            nx.laplacian_matrix(self.DiWG, weight=None).todense(), DiNL
        )
        np.testing.assert_equal(
            nx.laplacian_matrix(self.DiWG, weight="other").todense(), DiOL
        )

    def test_normalized_laplacian(self):
        "Generalized Graph Laplacian"
        # fmt: off
        G = np.array([[ 1.   , -0.408, -0.408, -0.577,  0.],
                      [-0.408,  1.   , -0.5  ,  0.   ,  0.],
                      [-0.408, -0.5  ,  1.   ,  0.   ,  0.],
                      [-0.577,  0.   ,  0.   ,  1.   ,  0.],
                      [ 0.   ,  0.   ,  0.   ,  0.   ,  0.]])
        GL = np.array([[ 1.   , -0.408, -0.408, -0.577,  0.   ],
                       [-0.408,  1.   , -0.5  ,  0.   ,  0.   ],
                       [-0.408, -0.5  ,  1.   ,  0.   ,  0.   ],
                       [-0.577,  0.   ,  0.   ,  1.   ,  0.   ],
                       [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ]])
        Lsl = np.array([[ 0.75  , -0.2887, -0.2887, -0.3536,  0.    ],
                        [-0.2887,  0.6667, -0.3333,  0.    ,  0.    ],
                        [-0.2887, -0.3333,  0.6667,  0.    ,  0.    ],
                        [-0.3536,  0.    ,  0.    ,  0.5   ,  0.    ],
                        [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ]])

        DiG = np.array([[ 1.    ,  0.    , -0.4082,  0.    ,  0.    ,  0.    ],
                        [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
                        [-0.4082,  0.    ,  1.    ,  0.    , -0.4082,  0.    ],
                        [ 0.    ,  0.    ,  0.    ,  1.    , -0.5   , -0.7071],
                        [ 0.    ,  0.    ,  0.    , -0.5   ,  1.    , -0.7071],
                        [ 0.    ,  0.    ,  0.    , -0.7071,  0.    ,  1.    ]])
        DiGL = np.array([[ 1.    ,  0.    , -0.4082,  0.    ,  0.    ,  0.    ],
                         [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
                         [-0.4082,  0.    ,  1.    , -0.4082,  0.    ,  0.    ],
                         [ 0.    ,  0.    ,  0.    ,  1.    , -0.5   , -0.7071],
                         [ 0.    ,  0.    ,  0.    , -0.5   ,  1.    , -0.7071],
                         [ 0.    ,  0.    ,  0.    ,  0.    , -0.7071,  1.    ]])
        DiLsl = np.array([[ 0.6667, -0.5774, -0.2887,  0.    ,  0.    ,  0.    ],
                          [ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.    ],
                          [-0.2887, -0.5   ,  0.75  , -0.2887,  0.    ,  0.    ],
                          [ 0.    ,  0.    ,  0.    ,  0.6667, -0.3333, -0.4082],
                          [ 0.    ,  0.    ,  0.    , -0.3333,  0.6667, -0.4082],
                          [ 0.    ,  0.    ,  0.    ,  0.    , -0.4082,  0.5   ]])
        # fmt: on

        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.G, nodelist=range(5)).todense(),
            G,
            decimal=3,
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.G).todense(), GL, decimal=3
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.MG).todense(), GL, decimal=3
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.WG).todense(), GL, decimal=3
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.WG, weight="other").todense(),
            GL,
            decimal=3,
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.Gsl).todense(), Lsl, decimal=3
        )

        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(
                self.DiG,
                nodelist=range(1, 1 + 6),
            ).todense(),
            DiG,
            decimal=3,
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.DiG).todense(), DiGL, decimal=3
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.DiMG).todense(), DiGL, decimal=3
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.DiWG).todense(), DiGL, decimal=3
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.DiWG, weight="other").todense(),
            DiGL,
            decimal=3,
        )
        np.testing.assert_almost_equal(
            nx.normalized_laplacian_matrix(self.DiGsl).todense(), DiLsl, decimal=3
        )


def test_directed_laplacian():
    "Directed Laplacian"
    # Graph used as an example in Sec. 4.1 of Langville and Meyer,
    # "Google's PageRank and Beyond". The graph contains dangling nodes, so
    # the pagerank random walk is selected by directed_laplacian
    G = nx.DiGraph()
    G.add_edges_from(
        (
            (1, 2),
            (1, 3),
            (3, 1),
            (3, 2),
            (3, 5),
            (4, 5),
            (4, 6),
            (5, 4),
            (5, 6),
            (6, 4),
        )
    )
    # fmt: off
    GL = np.array([[ 0.9833, -0.2941, -0.3882, -0.0291, -0.0231, -0.0261],
                   [-0.2941,  0.8333, -0.2339, -0.0536, -0.0589, -0.0554],
                   [-0.3882, -0.2339,  0.9833, -0.0278, -0.0896, -0.0251],
                   [-0.0291, -0.0536, -0.0278,  0.9833, -0.4878, -0.6675],
                   [-0.0231, -0.0589, -0.0896, -0.4878,  0.9833, -0.2078],
                   [-0.0261, -0.0554, -0.0251, -0.6675, -0.2078,  0.9833]])
    # fmt: on
    L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G))
    np.testing.assert_almost_equal(L, GL, decimal=3)

    # Make the graph strongly connected, so we can use a random and lazy walk
    G.add_edges_from(((2, 5), (6, 1)))
    # fmt: off
    GL = np.array([[ 1.    , -0.3062, -0.4714,  0.    ,  0.    , -0.3227],
                   [-0.3062,  1.    , -0.1443,  0.    , -0.3162,  0.    ],
                   [-0.4714, -0.1443,  1.    ,  0.    , -0.0913,  0.    ],
                   [ 0.    ,  0.    ,  0.    ,  1.    , -0.5   , -0.5   ],
                   [ 0.    , -0.3162, -0.0913, -0.5   ,  1.    , -0.25  ],
                   [-0.3227,  0.    ,  0.    , -0.5   , -0.25  ,  1.    ]])
    # fmt: on
    L = nx.directed_laplacian_matrix(
        G, alpha=0.9, nodelist=sorted(G), walk_type="random"
    )
    np.testing.assert_almost_equal(L, GL, decimal=3)

    # fmt: off
    GL = np.array([[ 0.5   , -0.1531, -0.2357,  0.    ,  0.    , -0.1614],
                   [-0.1531,  0.5   , -0.0722,  0.    , -0.1581,  0.    ],
                   [-0.2357, -0.0722,  0.5   ,  0.    , -0.0456,  0.    ],
                   [ 0.    ,  0.    ,  0.    ,  0.5   , -0.25  , -0.25  ],
                   [ 0.    , -0.1581, -0.0456, -0.25  ,  0.5   , -0.125 ],
                   [-0.1614,  0.    ,  0.    , -0.25  , -0.125 ,  0.5   ]])
    # fmt: on
    L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G), walk_type="lazy")
    np.testing.assert_almost_equal(L, GL, decimal=3)

    # Make a strongly connected periodic graph
    G = nx.DiGraph()
    G.add_edges_from(((1, 2), (2, 4), (4, 1), (1, 3), (3, 4)))
    # fmt: off
    GL = np.array([[ 0.5  , -0.176, -0.176, -0.25 ],
                   [-0.176,  0.5  ,  0.   , -0.176],
                   [-0.176,  0.   ,  0.5  , -0.176],
                   [-0.25 , -0.176, -0.176,  0.5  ]])
    # fmt: on
    L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G))
    np.testing.assert_almost_equal(L, GL, decimal=3)


def test_directed_combinatorial_laplacian():
    "Directed combinatorial Laplacian"
    # Graph used as an example in Sec. 4.1 of Langville and Meyer,
    # "Google's PageRank and Beyond". The graph contains dangling nodes, so
    # the pagerank random walk is selected by directed_laplacian
    G = nx.DiGraph()
    G.add_edges_from(
        (
            (1, 2),
            (1, 3),
            (3, 1),
            (3, 2),
            (3, 5),
            (4, 5),
            (4, 6),
            (5, 4),
            (5, 6),
            (6, 4),
        )
    )
    # fmt: off
    GL = np.array([[ 0.0366, -0.0132, -0.0153, -0.0034, -0.0020, -0.0027],
                   [-0.0132,  0.0450, -0.0111, -0.0076, -0.0062, -0.0069],
                   [-0.0153, -0.0111,  0.0408, -0.0035, -0.0083, -0.0027],
                   [-0.0034, -0.0076, -0.0035,  0.3688, -0.1356, -0.2187],
                   [-0.0020, -0.0062, -0.0083, -0.1356,  0.2026, -0.0505],
                   [-0.0027, -0.0069, -0.0027, -0.2187, -0.0505,  0.2815]])
    # fmt: on

    L = nx.directed_combinatorial_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G))
    np.testing.assert_almost_equal(L, GL, decimal=3)

    # Make the graph strongly connected, so we can use a random and lazy walk
    G.add_edges_from(((2, 5), (6, 1)))

    # fmt: off
    GL = np.array([[ 0.1395, -0.0349, -0.0465,  0.    ,  0.    , -0.0581],
                   [-0.0349,  0.093 , -0.0116,  0.    , -0.0465,  0.    ],
                   [-0.0465, -0.0116,  0.0698,  0.    , -0.0116,  0.    ],
                   [ 0.    ,  0.    ,  0.    ,  0.2326, -0.1163, -0.1163],
                   [ 0.    , -0.0465, -0.0116, -0.1163,  0.2326, -0.0581],
                   [-0.0581,  0.    ,  0.    , -0.1163, -0.0581,  0.2326]])
    # fmt: on

    L = nx.directed_combinatorial_laplacian_matrix(
        G, alpha=0.9, nodelist=sorted(G), walk_type="random"
    )
    np.testing.assert_almost_equal(L, GL, decimal=3)

    # fmt: off
    GL = np.array([[ 0.0698, -0.0174, -0.0233,  0.    ,  0.    , -0.0291],
                   [-0.0174,  0.0465, -0.0058,  0.    , -0.0233,  0.    ],
                   [-0.0233, -0.0058,  0.0349,  0.    , -0.0058,  0.    ],
                   [ 0.    ,  0.    ,  0.    ,  0.1163, -0.0581, -0.0581],
                   [ 0.    , -0.0233, -0.0058, -0.0581,  0.1163, -0.0291],
                   [-0.0291,  0.    ,  0.    , -0.0581, -0.0291,  0.1163]])
    # fmt: on

    L = nx.directed_combinatorial_laplacian_matrix(
        G, alpha=0.9, nodelist=sorted(G), walk_type="lazy"
    )
    np.testing.assert_almost_equal(L, GL, decimal=3)

    E = nx.DiGraph(margulis_gabber_galil_graph(2))
    L = nx.directed_combinatorial_laplacian_matrix(E)
    # fmt: off
    expected = np.array(
        [[ 0.16666667, -0.08333333, -0.08333333,  0.        ],
         [-0.08333333,  0.16666667,  0.        , -0.08333333],
         [-0.08333333,  0.        ,  0.16666667, -0.08333333],
         [ 0.        , -0.08333333, -0.08333333,  0.16666667]]
    )
    # fmt: on
    np.testing.assert_almost_equal(L, expected, decimal=6)

    with pytest.raises(nx.NetworkXError):
        nx.directed_combinatorial_laplacian_matrix(G, walk_type="pagerank", alpha=100)
    with pytest.raises(nx.NetworkXError):
        nx.directed_combinatorial_laplacian_matrix(G, walk_type="silly")
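A minimal sketch (not part of this commit) of the Laplacian helpers the tests above cover, on an illustrative path graph:

import networkx as nx
import numpy as np

G = nx.path_graph(3)  # 0 - 1 - 2
L = nx.laplacian_matrix(G).todense()             # L = D - A
N = nx.normalized_laplacian_matrix(G).todense()  # I - D^{-1/2} A D^{-1/2}
print(np.asarray(L))  # [[ 1 -1  0], [-1  2 -1], [ 0 -1  1]]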
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_modularity.py
ADDED
@@ -0,0 +1,87 @@
import pytest

np = pytest.importorskip("numpy")
pytest.importorskip("scipy")

import networkx as nx
from networkx.generators.degree_seq import havel_hakimi_graph


class TestModularity:
    @classmethod
    def setup_class(cls):
        deg = [3, 2, 2, 1, 0]
        cls.G = havel_hakimi_graph(deg)
        # Graph used as an example in Sec. 4.1 of Langville and Meyer,
        # "Google's PageRank and Beyond". (Used for test_directed_laplacian)
        cls.DG = nx.DiGraph()
        cls.DG.add_edges_from(
            (
                (1, 2),
                (1, 3),
                (3, 1),
                (3, 2),
                (3, 5),
                (4, 5),
                (4, 6),
                (5, 4),
                (5, 6),
                (6, 4),
            )
        )

    def test_modularity(self):
        "Modularity matrix"
        # fmt: off
        B = np.array([[-1.125,  0.25 ,  0.25 ,  0.625,  0.   ],
                      [ 0.25 , -0.5  ,  0.5  , -0.25 ,  0.   ],
                      [ 0.25 ,  0.5  , -0.5  , -0.25 ,  0.   ],
                      [ 0.625, -0.25 , -0.25 , -0.125,  0.   ],
                      [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ]])
        # fmt: on

        permutation = [4, 0, 1, 2, 3]
        np.testing.assert_equal(nx.modularity_matrix(self.G), B)
        np.testing.assert_equal(
            nx.modularity_matrix(self.G, nodelist=permutation),
            B[np.ix_(permutation, permutation)],
        )

    def test_modularity_weight(self):
        "Modularity matrix with weights"
        # fmt: off
        B = np.array([[-1.125,  0.25 ,  0.25 ,  0.625,  0.   ],
                      [ 0.25 , -0.5  ,  0.5  , -0.25 ,  0.   ],
                      [ 0.25 ,  0.5  , -0.5  , -0.25 ,  0.   ],
                      [ 0.625, -0.25 , -0.25 , -0.125,  0.   ],
                      [ 0.   ,  0.   ,  0.   ,  0.   ,  0.   ]])
        # fmt: on

        G_weighted = self.G.copy()
        for n1, n2 in G_weighted.edges():
            G_weighted.edges[n1, n2]["weight"] = 0.5
        # The following test would fail in networkx 1.1
        np.testing.assert_equal(nx.modularity_matrix(G_weighted), B)
        # The following tests that the modularity matrix gets rescaled accordingly
        np.testing.assert_equal(
            nx.modularity_matrix(G_weighted, weight="weight"), 0.5 * B
        )

    def test_directed_modularity(self):
        "Directed Modularity matrix"
        # fmt: off
        B = np.array([[-0.2,  0.6,  0.8, -0.4, -0.4, -0.4],
                      [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
                      [ 0.7,  0.4, -0.3, -0.6,  0.4, -0.6],
                      [-0.2, -0.4, -0.2, -0.4,  0.6,  0.6],
                      [-0.2, -0.4, -0.2,  0.6, -0.4,  0.6],
                      [-0.1, -0.2, -0.1,  0.8, -0.2, -0.2]])
        # fmt: on
        node_permutation = [5, 1, 2, 3, 4, 6]
        idx_permutation = [4, 0, 1, 2, 3, 5]
        mm = nx.directed_modularity_matrix(self.DG, nodelist=sorted(self.DG))
        np.testing.assert_equal(mm, B)
        np.testing.assert_equal(
            nx.directed_modularity_matrix(self.DG, nodelist=node_permutation),
            B[np.ix_(idx_permutation, idx_permutation)],
        )
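For reference, a minimal sketch (not part of this commit) checking the definition these modularity tests rely on, B = A - k k^T / (2m) for an undirected graph:

import networkx as nx
import numpy as np

G = nx.path_graph(3)
B = nx.modularity_matrix(G)
k = np.array([d for _, d in G.degree()])  # degree vector
m = G.number_of_edges()
A = nx.to_numpy_array(G)
np.testing.assert_allclose(np.asarray(B), A - np.outer(k, k) / (2 * m))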
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_spectrum.py
ADDED
@@ -0,0 +1,71 @@
import pytest

np = pytest.importorskip("numpy")
pytest.importorskip("scipy")

import networkx as nx
from networkx.generators.degree_seq import havel_hakimi_graph


class TestSpectrum:
    @classmethod
    def setup_class(cls):
        deg = [3, 2, 2, 1, 0]
        cls.G = havel_hakimi_graph(deg)
        cls.P = nx.path_graph(3)
        cls.WG = nx.Graph(
            (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges()
        )
        cls.WG.add_node(4)
        cls.DG = nx.DiGraph()
        nx.add_path(cls.DG, [0, 1, 2])

    def test_laplacian_spectrum(self):
        "Laplacian eigenvalues"
        evals = np.array([0, 0, 1, 3, 4])
        e = sorted(nx.laplacian_spectrum(self.G))
        np.testing.assert_almost_equal(e, evals)
        e = sorted(nx.laplacian_spectrum(self.WG, weight=None))
        np.testing.assert_almost_equal(e, evals)
        e = sorted(nx.laplacian_spectrum(self.WG))
        np.testing.assert_almost_equal(e, 0.5 * evals)
        e = sorted(nx.laplacian_spectrum(self.WG, weight="other"))
        np.testing.assert_almost_equal(e, 0.3 * evals)

    def test_normalized_laplacian_spectrum(self):
        "Normalized Laplacian eigenvalues"
        evals = np.array([0, 0, 0.7712864461218, 1.5, 1.7287135538781])
        e = sorted(nx.normalized_laplacian_spectrum(self.G))
        np.testing.assert_almost_equal(e, evals)
        e = sorted(nx.normalized_laplacian_spectrum(self.WG, weight=None))
        np.testing.assert_almost_equal(e, evals)
        e = sorted(nx.normalized_laplacian_spectrum(self.WG))
        np.testing.assert_almost_equal(e, evals)
        e = sorted(nx.normalized_laplacian_spectrum(self.WG, weight="other"))
        np.testing.assert_almost_equal(e, evals)

    def test_adjacency_spectrum(self):
        "Adjacency eigenvalues"
        evals = np.array([-np.sqrt(2), 0, np.sqrt(2)])
        e = sorted(nx.adjacency_spectrum(self.P))
        np.testing.assert_almost_equal(e, evals)

    def test_modularity_spectrum(self):
        "Modularity eigenvalues"
        evals = np.array([-1.5, 0.0, 0.0])
        e = sorted(nx.modularity_spectrum(self.P))
        np.testing.assert_almost_equal(e, evals)
        # Directed modularity eigenvalues
        evals = np.array([-0.5, 0.0, 0.0])
        e = sorted(nx.modularity_spectrum(self.DG))
        np.testing.assert_almost_equal(e, evals)

    def test_bethe_hessian_spectrum(self):
        "Bethe Hessian eigenvalues"
        evals = np.array([0.5 * (9 - np.sqrt(33)), 4, 0.5 * (9 + np.sqrt(33))])
        e = sorted(nx.bethe_hessian_spectrum(self.P, r=2))
        np.testing.assert_almost_equal(e, evals)
        # Collapses back to Laplacian:
        e1 = sorted(nx.bethe_hessian_spectrum(self.P, r=1))
        e2 = sorted(nx.laplacian_spectrum(self.P))
        np.testing.assert_almost_equal(e1, e2)
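A minimal sketch (not part of this commit) of the spectrum helpers tested above; each returns the eigenvalues of the corresponding matrix:

import networkx as nx
import numpy as np

P = nx.path_graph(3)
print(sorted(nx.laplacian_spectrum(P)))           # ~[0, 1, 3]
print(sorted(np.real(nx.adjacency_spectrum(P))))  # ~[-sqrt(2), 0, sqrt(2)]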
pythonProject/.venv/Lib/site-packages/numpy/ctypeslib.py
ADDED
@@ -0,0 +1,602 @@
"""
============================
``ctypes`` Utility Functions
============================

See Also
--------
load_library : Load a C library.
ndpointer : Array restype/argtype with verification.
as_ctypes : Create a ctypes array from an ndarray.
as_array : Create an ndarray from a ctypes array.

References
----------
.. [1] "SciPy Cookbook: ctypes", https://scipy-cookbook.readthedocs.io/items/Ctypes.html

Examples
--------
Load the C library:

>>> _lib = np.ctypeslib.load_library('libmystuff', '.')  #doctest: +SKIP

Our result type, an ndarray that must be of type double, be 1-dimensional,
and be C-contiguous in memory:

>>> array_1d_double = np.ctypeslib.ndpointer(
...     dtype=np.double,
...     ndim=1, flags='CONTIGUOUS')  #doctest: +SKIP

Our C-function typically takes an array and updates its values
in-place. For example::

    void foo_func(double* x, int length)
    {
        int i;
        for (i = 0; i < length; i++) {
            x[i] = i*i;
        }
    }

We wrap it using:

>>> _lib.foo_func.restype = None  #doctest: +SKIP
>>> _lib.foo_func.argtypes = [array_1d_double, c_int]  #doctest: +SKIP

Then, we're ready to call ``foo_func``:

>>> out = np.empty(15, dtype=np.double)
>>> _lib.foo_func(out, len(out))  #doctest: +SKIP

"""
__all__ = ['load_library', 'ndpointer', 'c_intp', 'as_ctypes', 'as_array',
           'as_ctypes_type']

import os
import numpy as np
from numpy._core.multiarray import _flagdict, flagsobj

try:
    import ctypes
except ImportError:
    ctypes = None

if ctypes is None:
    def _dummy(*args, **kwds):
        """
        Dummy object that raises an ImportError if ctypes is not available.

        Raises
        ------
        ImportError
            If ctypes is not available.

        """
        raise ImportError("ctypes is not available.")
    load_library = _dummy
    as_ctypes = _dummy
    as_array = _dummy
    from numpy import intp as c_intp
    _ndptr_base = object
else:
    import numpy._core._internal as nic
    c_intp = nic._getintp_ctype()
    del nic
    _ndptr_base = ctypes.c_void_p

    # Adapted from Albert Strasheim
    def load_library(libname, loader_path):
        """
        It is possible to load a library using

        >>> lib = ctypes.cdll[<full_path_name>]  # doctest: +SKIP

        But there are cross-platform considerations, such as library file
        extensions, plus the fact that Windows will just load the first
        library it finds with that name.  NumPy supplies the load_library
        function as a convenience.

        .. versionchanged:: 1.20.0
            Allow libname and loader_path to take any
            :term:`python:path-like object`.

        Parameters
        ----------
        libname : path-like
            Name of the library, which can have 'lib' as a prefix,
            but without an extension.
        loader_path : path-like
            Where the library can be found.

        Returns
        -------
        ctypes.cdll[libpath] : library object
            A ctypes library object

        Raises
        ------
        OSError
            If there is no library with the expected extension, or the
            library is defective and cannot be loaded.
        """
        # Convert path-like objects into strings
        libname = os.fsdecode(libname)
        loader_path = os.fsdecode(loader_path)

        ext = os.path.splitext(libname)[1]
        if not ext:
            import sys
            import sysconfig
            # Try to load library with platform-specific name, otherwise
            # default to libname.[so|dll|dylib].  Sometimes, these files are
            # built erroneously on non-linux platforms.
            base_ext = ".so"
            if sys.platform.startswith("darwin"):
                base_ext = ".dylib"
            elif sys.platform.startswith("win"):
                base_ext = ".dll"
            libname_ext = [libname + base_ext]
            so_ext = sysconfig.get_config_var("EXT_SUFFIX")
            if not so_ext == base_ext:
                libname_ext.insert(0, libname + so_ext)
        else:
            libname_ext = [libname]

        loader_path = os.path.abspath(loader_path)
        if not os.path.isdir(loader_path):
            libdir = os.path.dirname(loader_path)
        else:
            libdir = loader_path

        for ln in libname_ext:
            libpath = os.path.join(libdir, ln)
            if os.path.exists(libpath):
                try:
                    return ctypes.cdll[libpath]
                except OSError:
                    ## defective lib file
                    raise
        ## if no successful return in the libname_ext loop:
        raise OSError("no file with expected extension")


def _num_fromflags(flaglist):
    num = 0
    for val in flaglist:
        num += _flagdict[val]
    return num

_flagnames = ['C_CONTIGUOUS', 'F_CONTIGUOUS', 'ALIGNED', 'WRITEABLE',
              'OWNDATA', 'WRITEBACKIFCOPY']
def _flags_fromnum(num):
    res = []
    for key in _flagnames:
        value = _flagdict[key]
        if (num & value):
            res.append(key)
    return res


class _ndptr(_ndptr_base):
    @classmethod
    def from_param(cls, obj):
        if not isinstance(obj, np.ndarray):
            raise TypeError("argument must be an ndarray")
        if cls._dtype_ is not None \
                and obj.dtype != cls._dtype_:
            raise TypeError("array must have data type %s" % cls._dtype_)
        if cls._ndim_ is not None \
                and obj.ndim != cls._ndim_:
            raise TypeError("array must have %d dimension(s)" % cls._ndim_)
        if cls._shape_ is not None \
                and obj.shape != cls._shape_:
            raise TypeError("array must have shape %s" % str(cls._shape_))
        if cls._flags_ is not None \
                and ((obj.flags.num & cls._flags_) != cls._flags_):
            raise TypeError("array must have flags %s" %
                            _flags_fromnum(cls._flags_))
        return obj.ctypes


class _concrete_ndptr(_ndptr):
    """
    Like _ndptr, but with `_shape_` and `_dtype_` specified.

    Notably, this means the pointer has enough information to reconstruct
    the array, which is not generally true.
    """
    def _check_retval_(self):
        """
        This method is called when this class is used as the .restype
        attribute for a shared-library function, to automatically wrap the
        pointer into an array.
        """
        return self.contents

    @property
    def contents(self):
        """
        Get an ndarray viewing the data pointed to by this pointer.

        This mirrors the `contents` attribute of a normal ctypes pointer
        """
        full_dtype = np.dtype((self._dtype_, self._shape_))
        full_ctype = ctypes.c_char * full_dtype.itemsize
        buffer = ctypes.cast(self, ctypes.POINTER(full_ctype)).contents
        return np.frombuffer(buffer, dtype=full_dtype).squeeze(axis=0)


# Factory for an array-checking class with from_param defined for
# use with ctypes argtypes mechanism
_pointer_type_cache = {}
def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
    """
    Array-checking restype/argtypes.

    An ndpointer instance is used to describe an ndarray in restypes
    and argtypes specifications.  This approach is more flexible than
    using, for example, ``POINTER(c_double)``, since several restrictions
    can be specified, which are verified upon calling the ctypes function.
    These include data type, number of dimensions, shape and flags.  If a
    given array does not satisfy the specified restrictions,
    a ``TypeError`` is raised.

    Parameters
    ----------
    dtype : data-type, optional
        Array data-type.
    ndim : int, optional
        Number of array dimensions.
    shape : tuple of ints, optional
        Array shape.
    flags : str or tuple of str
        Array flags; may be one or more of:

        - C_CONTIGUOUS / C / CONTIGUOUS
        - F_CONTIGUOUS / F / FORTRAN
        - OWNDATA / O
        - WRITEABLE / W
        - ALIGNED / A
        - WRITEBACKIFCOPY / X

    Returns
    -------
    klass : ndpointer type object
        A type object, which is an ``_ndptr`` instance containing
        dtype, ndim, shape and flags information.

    Raises
    ------
    TypeError
        If a given array does not satisfy the specified restrictions.

    Examples
    --------
    >>> clib.somefunc.argtypes = [np.ctypeslib.ndpointer(dtype=np.float64,
    ...                                                  ndim=1,
    ...                                                  flags='C_CONTIGUOUS')]
    ... #doctest: +SKIP
    >>> clib.somefunc(np.array([1, 2, 3], dtype=np.float64))
    ... #doctest: +SKIP

    """

    # normalize dtype to dtype | None
    if dtype is not None:
        dtype = np.dtype(dtype)

    # normalize flags to int | None
    num = None
    if flags is not None:
        if isinstance(flags, str):
            flags = flags.split(',')
        elif isinstance(flags, (int, np.integer)):
            num = flags
            flags = _flags_fromnum(num)
        elif isinstance(flags, flagsobj):
            num = flags.num
            flags = _flags_fromnum(num)
        if num is None:
            try:
                flags = [x.strip().upper() for x in flags]
            except Exception as e:
                raise TypeError("invalid flags specification") from e
            num = _num_fromflags(flags)

    # normalize shape to tuple | None
    if shape is not None:
        try:
            shape = tuple(shape)
        except TypeError:
            # single integer -> 1-tuple
            shape = (shape,)

    cache_key = (dtype, ndim, shape, num)

    try:
        return _pointer_type_cache[cache_key]
    except KeyError:
        pass

    # produce a name for the new type
    if dtype is None:
        name = 'any'
    elif dtype.names is not None:
        name = str(id(dtype))
    else:
        name = dtype.str
    if ndim is not None:
        name += "_%dd" % ndim
    if shape is not None:
        name += "_"+"x".join(str(x) for x in shape)
    if flags is not None:
        name += "_"+"_".join(flags)

    if dtype is not None and shape is not None:
        base = _concrete_ndptr
    else:
        base = _ndptr

    klass = type("ndpointer_%s" % name, (base,),
                 {"_dtype_": dtype,
                  "_shape_": shape,
                  "_ndim_": ndim,
                  "_flags_": num})
    _pointer_type_cache[cache_key] = klass
    return klass


if ctypes is not None:
    def _ctype_ndarray(element_type, shape):
        """ Create an ndarray of the given element type and shape """
        for dim in shape[::-1]:
            element_type = dim * element_type
            # prevent the type name from including np.ctypeslib
            element_type.__module__ = None
        return element_type


    def _get_scalar_type_map():
        """
        Return a dictionary mapping native endian scalar dtype to ctypes types
        """
        ct = ctypes
        simple_types = [
            ct.c_byte, ct.c_short, ct.c_int, ct.c_long, ct.c_longlong,
            ct.c_ubyte, ct.c_ushort, ct.c_uint, ct.c_ulong, ct.c_ulonglong,
            ct.c_float, ct.c_double,
            ct.c_bool,
        ]
        return {np.dtype(ctype): ctype for ctype in simple_types}


    _scalar_type_map = _get_scalar_type_map()


    def _ctype_from_dtype_scalar(dtype):
        # swapping twice ensures that `=` is promoted to <, >, or |
        dtype_with_endian = dtype.newbyteorder('S').newbyteorder('S')
        dtype_native = dtype.newbyteorder('=')
        try:
            ctype = _scalar_type_map[dtype_native]
        except KeyError as e:
            raise NotImplementedError(
                "Converting {!r} to a ctypes type".format(dtype)
            ) from None

        if dtype_with_endian.byteorder == '>':
            ctype = ctype.__ctype_be__
        elif dtype_with_endian.byteorder == '<':
            ctype = ctype.__ctype_le__

        return ctype


    def _ctype_from_dtype_subarray(dtype):
        element_dtype, shape = dtype.subdtype
        ctype = _ctype_from_dtype(element_dtype)
        return _ctype_ndarray(ctype, shape)


    def _ctype_from_dtype_structured(dtype):
        # extract offsets of each field
        field_data = []
        for name in dtype.names:
            field_dtype, offset = dtype.fields[name][:2]
            field_data.append((offset, name, _ctype_from_dtype(field_dtype)))

        # ctypes doesn't care about field order
        field_data = sorted(field_data, key=lambda f: f[0])

        if len(field_data) > 1 and all(offset == 0 for offset, name, ctype in field_data):
            # union, if multiple fields all at address 0
            size = 0
            _fields_ = []
            for offset, name, ctype in field_data:
                _fields_.append((name, ctype))
                size = max(size, ctypes.sizeof(ctype))

            # pad to the right size
            if dtype.itemsize != size:
                _fields_.append(('', ctypes.c_char * dtype.itemsize))

            # we inserted manual padding, so always `_pack_`
            return type('union', (ctypes.Union,), dict(
                _fields_=_fields_,
                _pack_=1,
                __module__=None,
            ))
        else:
            last_offset = 0
            _fields_ = []
            for offset, name, ctype in field_data:
                padding = offset - last_offset
                if padding < 0:
                    raise NotImplementedError("Overlapping fields")
                if padding > 0:
                    _fields_.append(('', ctypes.c_char * padding))

                _fields_.append((name, ctype))
                last_offset = offset + ctypes.sizeof(ctype)

            padding = dtype.itemsize - last_offset
            if padding > 0:
                _fields_.append(('', ctypes.c_char * padding))

            # we inserted manual padding, so always `_pack_`
            return type('struct', (ctypes.Structure,), dict(
                _fields_=_fields_,
                _pack_=1,
                __module__=None,
            ))


    def _ctype_from_dtype(dtype):
        if dtype.fields is not None:
            return _ctype_from_dtype_structured(dtype)
        elif dtype.subdtype is not None:
            return _ctype_from_dtype_subarray(dtype)
        else:
            return _ctype_from_dtype_scalar(dtype)


    def as_ctypes_type(dtype):
        r"""
        Convert a dtype into a ctypes type.

        Parameters
        ----------
        dtype : dtype
            The dtype to convert

        Returns
        -------
        ctype
            A ctype scalar, union, array, or struct

        Raises
        ------
        NotImplementedError
            If the conversion is not possible

        Notes
        -----
        This function does not losslessly round-trip in either direction.

        ``np.dtype(as_ctypes_type(dt))`` will:

        - insert padding fields
        - reorder fields to be sorted by offset
        - discard field titles

        ``as_ctypes_type(np.dtype(ctype))`` will:

        - discard the class names of `ctypes.Structure`\ s and
          `ctypes.Union`\ s
        - convert single-element `ctypes.Union`\ s into single-element
          `ctypes.Structure`\ s
        - insert padding fields

        Examples
        --------
        Converting a simple dtype:

        >>> dt = np.dtype('int8')
        >>> ctype = np.ctypeslib.as_ctypes_type(dt)
        >>> ctype
        <class 'ctypes.c_byte'>

        Converting a structured dtype:

        >>> dt = np.dtype([('x', 'i4'), ('y', 'f4')])
        >>> ctype = np.ctypeslib.as_ctypes_type(dt)
        >>> ctype
        <class 'struct'>

        """
        return _ctype_from_dtype(np.dtype(dtype))


    def as_array(obj, shape=None):
        """
        Create a numpy array from a ctypes array or POINTER.

        The numpy array shares the memory with the ctypes object.

        The shape parameter must be given if converting from a ctypes POINTER.
        The shape parameter is ignored if converting from a ctypes array.

        Examples
        --------
        Converting a ctypes integer array:

        >>> import ctypes
        >>> ctypes_array = (ctypes.c_int * 5)(0, 1, 2, 3, 4)
        >>> np_array = np.ctypeslib.as_array(ctypes_array)
        >>> np_array
        array([0, 1, 2, 3, 4], dtype=int32)

        Converting a ctypes POINTER:

        >>> import ctypes
        >>> buffer = (ctypes.c_int * 5)(0, 1, 2, 3, 4)
        >>> pointer = ctypes.cast(buffer, ctypes.POINTER(ctypes.c_int))
        >>> np_array = np.ctypeslib.as_array(pointer, (5,))
        >>> np_array
        array([0, 1, 2, 3, 4], dtype=int32)

        """
        if isinstance(obj, ctypes._Pointer):
            # convert pointers to an array of the desired shape
            if shape is None:
                raise TypeError(
                    'as_array() requires a shape argument when called on a '
                    'pointer')
            p_arr_type = ctypes.POINTER(_ctype_ndarray(obj._type_, shape))
|
| 556 |
+
obj = ctypes.cast(obj, p_arr_type).contents
|
| 557 |
+
|
| 558 |
+
return np.asarray(obj)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def as_ctypes(obj):
|
| 562 |
+
"""
|
| 563 |
+
Create and return a ctypes object from a numpy array. Actually
|
| 564 |
+
anything that exposes the __array_interface__ is accepted.
|
| 565 |
+
|
| 566 |
+
Examples
|
| 567 |
+
--------
|
| 568 |
+
Create ctypes object from inferred int ``np.array``:
|
| 569 |
+
|
| 570 |
+
>>> inferred_int_array = np.array([1, 2, 3])
|
| 571 |
+
>>> c_int_array = np.ctypeslib.as_ctypes(inferred_int_array)
|
| 572 |
+
>>> type(c_int_array)
|
| 573 |
+
<class 'c_long_Array_3'>
|
| 574 |
+
>>> c_int_array[:]
|
| 575 |
+
[1, 2, 3]
|
| 576 |
+
|
| 577 |
+
Create ctypes object from explicit 8 bit unsigned int ``np.array`` :
|
| 578 |
+
|
| 579 |
+
>>> exp_int_array = np.array([1, 2, 3], dtype=np.uint8)
|
| 580 |
+
>>> c_int_array = np.ctypeslib.as_ctypes(exp_int_array)
|
| 581 |
+
>>> type(c_int_array)
|
| 582 |
+
<class 'c_ubyte_Array_3'>
|
| 583 |
+
>>> c_int_array[:]
|
| 584 |
+
[1, 2, 3]
|
| 585 |
+
|
| 586 |
+
"""
|
| 587 |
+
ai = obj.__array_interface__
|
| 588 |
+
if ai["strides"]:
|
| 589 |
+
raise TypeError("strided arrays not supported")
|
| 590 |
+
if ai["version"] != 3:
|
| 591 |
+
raise TypeError("only __array_interface__ version 3 supported")
|
| 592 |
+
addr, readonly = ai["data"]
|
| 593 |
+
if readonly:
|
| 594 |
+
raise TypeError("readonly arrays unsupported")
|
| 595 |
+
|
| 596 |
+
# can't use `_dtype((ai["typestr"], ai["shape"]))` here, as it overflows
|
| 597 |
+
# dtype.itemsize (gh-14214)
|
| 598 |
+
ctype_scalar = as_ctypes_type(ai["typestr"])
|
| 599 |
+
result_type = _ctype_ndarray(ctype_scalar, ai["shape"])
|
| 600 |
+
result = result_type.from_address(addr)
|
| 601 |
+
result.__keep = obj
|
| 602 |
+
return result
|
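To make the conversion rules above concrete, here is a small usage sketch (not part of the vendored file; variable names are illustrative). The exact class names of the generated struct/union types may vary across numpy versions.

import ctypes
import numpy as np

# Scalar conversion honors explicit byte order: '>i4' maps to the
# big-endian variant of the ctypes scalar (ctype.__ctype_be__ above).
be_int = np.ctypeslib.as_ctypes_type(np.dtype('>i4'))

# Fields at distinct offsets become a packed ctypes.Structure; multiple
# fields all at offset 0 would become a ctypes.Union instead.
point_t = np.ctypeslib.as_ctypes_type(np.dtype([('x', 'i4'), ('y', 'f4')]))

# as_ctypes shares memory with the source array (result.__keep pins it),
# so writes through the ctypes view are visible from numpy and vice versa.
arr = np.zeros(4, dtype=np.int32)
c_view = np.ctypeslib.as_ctypes(arr)
c_view[0] = 7
assert arr[0] == 7

# as_array on a ctypes array ignores `shape`; on a POINTER it requires it.
back = np.ctypeslib.as_array(c_view)
assert back.shape == (4,)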
pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h
ADDED
@@ -0,0 +1,103 @@
#pragma once

#include <c10/core/Device.h>
#include <c10/util/Exception.h>

#include <c10/core/Stream.h>
#include <c10/util/Registry.h>

#include <ATen/detail/AcceleratorHooksInterface.h>

#include <string>

namespace at {
class Context;
}

namespace at {

constexpr const char* MTIA_HELP =
    "The MTIA backend requires the MTIA extension for PyTorch; "
    "this error has occurred because you are trying "
    "to use some MTIA functionality without the MTIA extension included.";

struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
// Fail loudly if an MTIAHooks function is called while the MTIA backend
// is not present.
#define FAIL_MTIAHOOKS_FUNC(func) \
  TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");

  ~MTIAHooksInterface() override = default;

  virtual void initMTIA() const {
    // Avoid logging here, since MTIA needs to init devices first, and only
    // then knows how many devices are available. Make this a no-op if the
    // mtia extension is not dynamically loaded.
    return;
  }

  virtual bool hasMTIA() const {
    return false;
  }

  DeviceIndex deviceCount() const override {
    return 0;
  }

  virtual void deviceSynchronize(c10::DeviceIndex device_index) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  virtual std::string showConfig() const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  bool hasPrimaryContext(DeviceIndex device_index) const override {
    return false;
  }

  void setCurrentDevice(DeviceIndex device) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }

  DeviceIndex getCurrentDevice() const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  DeviceIndex exchangeDevice(DeviceIndex device) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return -1;
  }

  virtual c10::Stream getCurrentStream(DeviceIndex device) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  }

  virtual c10::Stream getDefaultStream(DeviceIndex device) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
  }

  virtual void setCurrentStream(const c10::Stream& stream) const {
    FAIL_MTIAHOOKS_FUNC(__func__);
  }
};

struct TORCH_API MTIAHooksArgs {};

C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
#define REGISTER_MTIA_HOOKS(clsname) \
  C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)

namespace detail {
TORCH_API const MTIAHooksInterface& getMTIAHooks();
TORCH_API bool isMTIAHooksBuilt();
} // namespace detail
} // namespace at
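The header above is an instance of PyTorch's accelerator-hooks pattern: a default interface whose methods either no-op or fail loudly, plus a registry that a dynamically loaded backend populates. A minimal Python sketch of the same pattern (hypothetical names, not a torch API):

class DefaultHooks:
    """Default hooks: safe no-ops or loud failures, like MTIAHooksInterface."""

    def has_backend(self) -> bool:
        return False

    def device_count(self) -> int:
        return 0

    def device_synchronize(self, index: int) -> None:
        raise RuntimeError("Cannot execute device_synchronize() without the backend.")


_hooks_registry = {}  # analogue of C10_DECLARE_REGISTRY / REGISTER_MTIA_HOOKS

def register_hooks(name, hooks_cls):
    _hooks_registry[name] = hooks_cls

def get_hooks(name="mtia"):
    # analogue of detail::getMTIAHooks(): fall back to the default interface
    return _hooks_registry.get(name, DefaultHooks)()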
pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h
ADDED
@@ -0,0 +1,61 @@
#pragma once

#include <ATen/core/Generator.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/Storage.h>
#include <c10/util/Exception.h>
namespace at {

struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
  ~PrivateUse1HooksInterface() override = default;
  virtual const at::Generator& getDefaultGenerator(
      c10::DeviceIndex device_index) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false,
        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before calling `getDefaultGenerator`.");
  }

  virtual at::Device getDeviceFromPtr(void* data) const {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false,
        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before calling `getDeviceFromPtr`.");
  }

  virtual Allocator* getPinnedMemoryAllocator() const {
    TORCH_CHECK(
        false,
        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before calling `getPinnedMemoryAllocator`.");
  }

  bool hasPrimaryContext(DeviceIndex device_index) const override {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false,
        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before calling `hasPrimaryContext`.");
  }

  virtual void initPrivateUse1() const {}
  virtual void resizePrivateUse1Bytes(const c10::Storage& storage, size_t newsize) const {
    TORCH_CHECK_NOT_IMPLEMENTED(
        false,
        "You should register `PrivateUse1HooksInterface` for PrivateUse1 before calling `resizePrivateUse1Bytes`.");
  }
};

struct TORCH_API PrivateUse1HooksArgs {};

TORCH_API void RegisterPrivateUse1HooksInterface(
    at::PrivateUse1HooksInterface* hook_);

TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface();

TORCH_API bool isPrivateUse1HooksRegistered();

namespace detail {

TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();

} // namespace detail

} // namespace at
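From Python, out-of-tree backends that use these hooks typically claim the reserved PrivateUse1 dispatch key and rename it. A hedged sketch (the backend name is illustrative, and the registration errors above persist until a real C++ extension registers its hooks):

import torch

# Give the PrivateUse1 slot a user-facing name. Tensors on the backend then
# report device type "my_device"; the extension must still call
# RegisterPrivateUse1HooksInterface for generators, pinned memory, etc.
torch.utils.rename_privateuse1_backend("my_device")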
pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h
ADDED
@@ -0,0 +1,84 @@
#pragma once

#include <c10/core/Device.h>
#include <c10/util/Exception.h>
#include <ATen/core/Generator.h>
#include <c10/util/Registry.h>

namespace at {

constexpr const char* XPU_HELP =
    "The XPU backend requires the Intel Extension for PyTorch; "
    "this error has occurred because you are trying "
    "to use some XPU functionality, but the Intel Extension for PyTorch has not been "
    "loaded for some reason. The Intel Extension for PyTorch MUST "
    "be loaded, EVEN IF you don't directly use any symbols from it!";

struct TORCH_API XPUHooksInterface {
  virtual ~XPUHooksInterface() = default;

  virtual void initXPU() const {
    TORCH_CHECK(
        false,
        "Cannot initialize XPU without the Intel Extension for PyTorch.",
        XPU_HELP);
  }

  virtual bool hasXPU() const {
    return false;
  }

  virtual std::string showConfig() const {
    TORCH_CHECK(
        false,
        "Cannot query detailed XPU version without the Intel Extension for PyTorch. ",
        XPU_HELP);
  }

  virtual int32_t getGlobalIdxFromDevice(const Device& device) const {
    TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
  }

  virtual Generator getXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
    TORCH_CHECK(false, "Cannot get XPU generator without the Intel Extension for PyTorch. ", XPU_HELP);
  }

  virtual const Generator& getDefaultXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
    TORCH_CHECK(false, "Cannot get default XPU generator without the Intel Extension for PyTorch. ", XPU_HELP);
  }

  virtual DeviceIndex getNumGPUs() const {
    return 0;
  }

  virtual DeviceIndex current_device() const {
    TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library.");
  }

  virtual Device getDeviceFromPtr(void* /*data*/) const {
    TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
  }

  virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
    TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library.");
  }

  virtual Allocator* getPinnedMemoryAllocator() const {
    TORCH_CHECK(false, "Cannot get XPU pinned memory allocator without ATen_xpu library.");
  }

  virtual bool isPinnedPtr(const void* /*data*/) const {
    return false;
  }
};

struct TORCH_API XPUHooksArgs {};

C10_DECLARE_REGISTRY(XPUHooksRegistry, XPUHooksInterface, XPUHooksArgs);
#define REGISTER_XPU_HOOKS(clsname) \
  C10_REGISTER_CLASS(XPUHooksRegistry, clsname, clsname)

namespace detail {
TORCH_API const XPUHooksInterface& getXPUHooks();
} // namespace detail
} // namespace at
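Since hasXPU() is false until the extension loads, Python code should probe before touching XPU APIs. A guarded check (assuming a PyTorch build recent enough to expose torch.xpu):

import torch

if hasattr(torch, "xpu") and torch.xpu.is_available():
    print("XPU device count:", torch.xpu.device_count())
else:
    # Mirrors the XPU_HELP message: the hooks exist but stay inert
    # until the Intel extension (or ATen_xpu) is loaded.
    print("XPU not available")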
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h
ADDED
@@ -0,0 +1,38 @@
#pragma once
#include <ATen/functorch/Interpreter.h>

namespace at::functorch {

// These are the interpreters for our AD transforms
// (grad, vjp and jvp).
// See NOTE: [functorch interpreter stack] for more details.

struct TORCH_API GradInterpreterPtr {
  explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
  TransformType key() const { return base_->key(); }
  int64_t level() const { return base_->level(); }
  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  bool prevGradMode() const {
    return std::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
  }
  Tensor lift(const Tensor& tensor) const;
 private:
  const Interpreter* base_;
};

struct TORCH_API JvpInterpreterPtr {
  explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
  TransformType key() const { return base_->key(); }
  int64_t level() const { return base_->level(); }
  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  bool prevFwdGradMode() const {
    return std::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
  }
  Tensor lift(const Tensor& tensor) const;
 private:
  const Interpreter* base_;
};

} // namespace at::functorch
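These interpreters back the composable AD transforms exposed as torch.func.grad/jvp; each nested transform pushes one interpreter level onto the stack. A small sanity check of the nesting, using the standard torch.func API with values chosen for easy hand-checking:

import torch
from torch.func import grad

def f(x):
    return x ** 3

# One Grad interpreter level per grad() call: grad(grad(f)) is d2/dx2 x^3 = 6x.
d2f = grad(grad(f))
print(d2f(torch.tensor(2.0)))  # tensor(12.)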
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h
ADDED
@@ -0,0 +1,475 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#include <c10/util/TypeList.h>

#include <ATen/ATen.h>
#include <ATen/Operators.h>

#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchingMetaprogramming.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/VmapGeneratedPlumbing.h>

#include <utility>

// This file contains helper functions for batching rules.

namespace at::functorch {

TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);

TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x);

Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
int64_t numelWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);

void vmapIncompatibleInplaceError(const char* schema_name);

Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank);

void check_randomness(RandomnessType randomness);
void check_randomness(RandomnessType randomness, bool any_tensor_bdim);

inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
  if (has_bdim) {
    return tensor;
  }
  const auto sizes = tensor.sym_sizes();
  SymDimVector expanded_shape;
  expanded_shape.reserve(sizes.size());
  expanded_shape.emplace_back(std::move(batch_size));
  expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
  return tensor.expand_symint(expanded_shape);
}

#define VMAP_SUPPORT(op, batch_rule) \
  m.impl(#op, op ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);

#define VMAP_SUPPORT2(op, overload, batch_rule) \
  m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);

#define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
#define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));

// DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
template <typename A, A a, typename C>
struct BasicUnaryBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor,optional<int64_t>> apply(
      const Tensor& tensor,
      optional<int64_t> batch_dim,
      T... extra_args) {
    return std::make_tuple(Func(tensor, std::forward<T>(extra_args)...), batch_dim);
  }
};

// USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
// INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
// It is important that this macro is not passed a function pointer!!
#define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
    BasicUnaryBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define UNARY_POINTWISE(op) \
  VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));

template <typename A, A a, typename C>
struct VariadicBdimsBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor,optional<int64_t>> apply(
      const Tensor& tensor,
      optional<int64_t> batch_dim,
      T... extra_args) {
    auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
    return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
  }
};

// USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
// INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
// It is important that this macro is not passed a function pointer!!
#define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
    VariadicBdimsBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)

#define VARIADIC_BDIMS(op) \
  VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));

#define VARIADIC_BDIMS2(op, overload) \
  VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));

template<class F, F Func>
void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");

  int64_t cur_level = maybe_layer->layerId();

  auto orig_arguments = torch::jit::last(*stack, num_arguments);
  if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  auto arguments = torch::jit::pop(*stack, num_arguments);
  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
  std::vector<int64_t> tensor_pos;
  for (const auto idx : c10::irange(0, num_arguments)) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
      tensor_inputs.emplace_back(tensor_value, tensor_bdim);
      tensor_pos.push_back(static_cast<int64_t>(idx));
    }
  }
  Func(tensor_inputs);

  size_t tensor_idx = 0;
  TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
  for (const auto arg_idx : c10::irange(0, num_arguments)) {
    if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
      torch::jit::push(stack, arguments[arg_idx]);
    } else {
      TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
      torch::jit::push(stack, tensor_inputs[tensor_idx].first);
      tensor_idx++;
    }
  }

  op.callBoxed(stack);
  const auto returns = torch::jit::pop(*stack, num_returns);
  for (const auto& ret : returns) {
    if (ret.isTensor()) {
      torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}

inline void handle_pointwise_ops(std::vector<std::pair<Tensor, optional<int64_t>>> &tensor_inputs) {
  int64_t out_logical_rank = 0;
  for (auto& tensor_input : tensor_inputs) {
    int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
    out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
  }
  for (auto& tensor_input: tensor_inputs) {
    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
    tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
  }
}

#define POINTWISE_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());

#define POINTWISE_BOXED2(op, overload) \
  m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());

inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t>>> &tensor_inputs) {
  for (auto & tensor_input : tensor_inputs) {
    tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
  }
}

#define VARIADIC_BDIMS_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());

using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;

inline void find_and_unpack_tensors(
    const torch::jit::Stack* stack,
    int64_t num_args,
    int64_t cur_level,
    SmallVector<UnpackedBatchedTensor, 5>* tensors,
    SmallVector<int64_t, 5>* tensors_pos,
    int64_t* batch_size) {

  int64_t computed_batch_size = -1;
  int64_t args_begin = static_cast<int64_t>(stack->size()) - num_args;

  for (const auto idx : c10::irange(0, num_args)) {
    const auto& ivalue = (*stack)[args_begin + idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
    const auto& tensor_value = std::get<0>(unpacked);
    const auto tensor_bdim = std::get<1>(unpacked);
    if (tensor_bdim.has_value()) {
      auto candidate_batch_size = tensor_value.size(*tensor_bdim);
      if (computed_batch_size == -1) {
        computed_batch_size = candidate_batch_size;
      }
      TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
    }

    tensors->push_back(std::move(unpacked));
    tensors_pos->push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
  *batch_size = computed_batch_size;
}

inline void boxed_existing_bdim_all_batch_rule(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = static_cast<int64_t>(schema.arguments().size());

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
  int64_t cur_level = maybe_layer->layerId();

  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  int64_t args_begin = static_cast<int64_t>(stack->size()) - num_arguments;
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  find_and_unpack_tensors(
      stack, num_arguments, cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  // for each tensor, ensure it has a bdim and reshape it.
  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
  }
}

// Use when all tensor arguments accept one (normal) batch dim.
// This batching rule expands the batch dim on all Tensors, reshapes it into
// dim 0, calls the op, and then reshapes the batch dim out of dim 0.
// This is not the most efficient thing; if there are alternatives, please try
// to use them. Use this only as a last resort.
#define EXISTING_BDIM_ALL_BOXED(op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());

template <int64_t feature_rank, int64_t contig_tensor_index=-1>
inline void boxed_all_tensors_have_optional_bdim(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
  int64_t cur_level = maybe_layer->layerId();

  const auto arguments = torch::jit::last(stack, num_arguments);
  if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
    op.callBoxed(stack);
    return;
  }

  int64_t args_begin = static_cast<int64_t>(stack->size() - num_arguments);
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size = 0;

  find_and_unpack_tensors(
      stack, static_cast<int64_t>(num_arguments), cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  optional<bool> is_no_batch_dim_case;

  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    const auto logical_rank = rankWithoutBatchDim(value, bdim);

    if (!is_no_batch_dim_case.has_value()) {
      is_no_batch_dim_case = (logical_rank == feature_rank);
    }
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    if (*is_no_batch_dim_case) {
      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
      value_ = moveBatchDimToFront(value_, bdim);
      if (tensor_idx == contig_tensor_index) {
        value_ = value_.contiguous();
      }
      (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
      continue;
    }
    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
    value_ = reshape_dim_into(*bdim, 0, value_);
    if (tensor_idx == contig_tensor_index) {
      value_ = value_.contiguous();
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    if (*is_no_batch_dim_case) {
      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
    } else {
      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
    }
  }
}

// Useful for many NN operators.
// The operator must satisfy the following:
// - All arguments must accept an optional batch dim.
// - All arguments must be the same rank
#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
  m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());

#define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
  m.impl(#op, \
         torch::CppFunction::makeFromBoxedFunction<\
             boxed_all_tensors_have_optional_bdim<\
                 feature_rank, \
                 contig_tensor_index>\
         >());

template <typename A, A a, typename C>
struct ExistingBdimBatchRuleHelper;

template <typename F, F Func, typename A, typename... T>
struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
  static std::tuple<Tensor,optional<int64_t>> apply(
      const Tensor& self,
      optional<int64_t> self_bdim,
      T... extra_args) {
    auto self_ = reshape_dim_into(*self_bdim, 0, self);
    auto out = Func(self_, std::forward<T>(extra_args)...);
    return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
  }
};

// USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
// INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
// It is important that this macro is not passed a function pointer!!
#define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
    ExistingBdimBatchRuleHelper<\
      decltype(&fn),\
      &fn,\
      c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)


#define EXISTING_BDIM(op) \
  VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));

#define EXISTING_BDIM2(op, overload) \
  VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));

#define INVOKE(object,ptrToMember) ((object).*(ptrToMember))


template <typename F, F Method, typename... ExtraArgs>
Tensor& unary_inplace_batch_rule(Tensor& self, optional<int64_t>, ExtraArgs... extra_args) {
  INVOKE(self, Method)(std::forward<ExtraArgs>(extra_args)...);
  return self;
}

inline int64_t get_bdim_size4(
    const Tensor& a_value, optional<int64_t> a_bdim,
    const Tensor& b_value, optional<int64_t> b_bdim,
    const Tensor& c_value, optional<int64_t> c_bdim,
    const Tensor& d_value, optional<int64_t> d_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  if (c_bdim)
    return c_value.size(*c_bdim);
  if (d_bdim)
    return d_value.size(*d_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

inline int64_t get_bdim_size3(
    const Tensor& a_value, optional<int64_t> a_bdim,
    const Tensor& b_value, optional<int64_t> b_bdim,
    const Tensor& c_value, optional<int64_t> c_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  if (c_bdim)
    return c_value.size(*c_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

inline int64_t get_bdim_size2(
    const Tensor& a_value, optional<int64_t> a_bdim,
    const Tensor& b_value, optional<int64_t> b_bdim) {
  if (a_bdim)
    return a_value.size(*a_bdim);
  if (b_bdim)
    return b_value.size(*b_bdim);
  TORCH_INTERNAL_ASSERT(false);
}

// [start, start + 1, ..., stop - 1]
inline VmapDimVector range(int64_t start, int64_t stop) {
  TORCH_INTERNAL_ASSERT(stop >= start);
  VmapDimVector dims;
  dims.reserve(stop - start);
  for (int64_t i = start; i < stop; i++) {
    dims.emplace_back(i);
  }
  return dims;
}
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
    const Tensor& tensor, optional<int64_t> tensor_batch_dim, const Tensor& other, optional<int64_t> other_batch_dim,
    bool do_type_promotion=true);

} // namespace at::functorch
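The effect of these helpers is visible from Python: an op with a registered batch rule (e.g. a unary pointwise one wired up via UNARY_POINTWISE) lets vmap track the batch dim directly instead of looping over slices. A quick check with the public torch.func API:

import torch
from torch.func import vmap

x = torch.randn(8, 3)

# sin has a unary pointwise batch rule, so this never hits the fallback loop.
assert torch.allclose(vmap(torch.sin)(x), torch.sin(x))

# in_dims picks which dimension is the (private) batch dim; results are
# stacked along out_dims=0, so mapping over dim 1 yields shape (3, 8).
assert vmap(torch.sin, in_dims=1)(x).shape == (3, 8)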
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h
ADDED
@@ -0,0 +1,81 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

namespace at::functorch {

// This file contains code for the vmap fallback (also known as the
// BatchedTensor fallback or the Batched fallback). This code runs
// when an operation doesn't have a batching rule implemented.

// If an operator doesn't have a batching rule implemented then we fall back
// to this implementation. The fallback doesn't work on out= variants or
// view operations; that is, it works for out-of-place operations and
// in-place non-view operations.
//
// For out-of-place operations, the fallback effectively takes all of the
// BatchedTensors in `stack`, slices them, and runs `op` on all of the
// corresponding slices to produce slices of the outputs. The output slices
// then get `torch.stack`ed to create the final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

// The vmap fallback emits a warning by default, but it may be disabled if
// the user finds it to be too annoying.
TORCH_API bool isVmapFallbackWarningEnabled();
TORCH_API void setVmapFallbackWarningEnabled(bool enabled);

// Used for testing. The vmap fallback is enabled by default. When it is disabled,
// it raises an error.
TORCH_API bool isVmapFallbackEnabled();
TORCH_API void setVmapFallbackEnabled(bool enabled);

template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
  return buffer[0].to<A>();
}
template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
  return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
}
template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
  return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<C>());
}

// slow_fallback is a way to call the vmap fallback inside some boxed kernel.
// There is probably some better way to metaprogram this.
template <typename Ret>
Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
  std::vector<IValue> stack(args.begin(), args.end());
  batchedTensorForLoopFallback(op, &stack);
  return vector_to_result<Ret>(stack);
}

template <typename A, typename B>
std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
  std::vector<IValue> stack(args.begin(), args.end());
  batchedTensorForLoopFallback(op, &stack);
  return vector_to_result<A, B>(stack);
}

template <typename A, typename B, typename C>
std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
  std::vector<IValue> stack(args.begin(), args.end());
  batchedTensorForLoopFallback(op, &stack);
  return vector_to_result<A, B, C>(stack);
}


} // namespace at::functorch
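The out-of-place fallback described in the comment above is just slice, apply, stack. A pure-Python model of those semantics (a sketch only; the real C++ implementation works on the boxed dispatcher stack):

import torch

def for_loop_fallback(op, batched, *args):
    # Slice along the batch dim, run the op per example, then torch.stack the
    # results; the extra copy from stacking is the performance cost noted above.
    return torch.stack([op(batched[i], *args) for i in range(batched.shape[0])])

x = torch.randn(4, 3)
assert torch.allclose(for_loop_fallback(torch.sin, x), torch.sin(x))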
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h
ADDED
@@ -0,0 +1,169 @@
| 1 |
+
// Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the BSD-style license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#pragma once
|
| 8 |
+
|
| 9 |
+
#include <bitset>
|
| 10 |
+
|
| 11 |
+
#include <ATen/ArrayRef.h>
|
| 12 |
+
#include <ATen/SmallVector.h>
|
| 13 |
+
#include <ATen/Tensor.h>
|
| 14 |
+
|
| 15 |
+
namespace at::functorch {
|
| 16 |
+
|
| 17 |
+
using Tensor = at::Tensor;
|
| 18 |
+
|
| 19 |
+
// We assume this in a few other places in the codebase,
|
| 20 |
+
// but there isn't a centralized definition.
|
| 21 |
+
constexpr int64_t kVmapMaxTensorDims = 64;
|
| 22 |
+
|
| 23 |
+
// The valid vmap levels range from [0, 64). This effectively means that we
|
| 24 |
+
// support a maximum of 64 nested vmaps.
|
| 25 |
+
constexpr int64_t kVmapNumLevels = 64;
|
| 26 |
+
|
| 27 |
+
// Store this number of elements of BatchDims on the stack. Most people will
|
| 28 |
+
// probably use <= 5 nested vmaps, but adjust this number as necessary.
|
| 29 |
+
constexpr int64_t kBatchDimsStackSize = 5;
|
| 30 |
+
|
| 31 |
+
// A BatchedTensorImpl holds an underlying Tensor and a single batch dim
|
| 32 |
+
// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
|
| 33 |
+
// BatchedTensorImpl.
|
| 34 |
+
//
|
| 35 |
+
// The batch dimensions are treated as being "private"; they are not user-visible.
|
| 36 |
+
// For example, in the following Tensor,
|
| 37 |
+
// bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
|
| 38 |
+
// dimension 0 is batch dimension.
|
| 39 |
+
//
|
| 40 |
+
// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public)
|
| 41 |
+
// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) tensor.
|
| 42 |
+
struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
|
| 43 |
+
explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);
|
| 44 |
+
|
| 45 |
+
// Returns batch dimension of this tensor
|
| 46 |
+
int64_t bdim() const { return bdim_; }
|
| 47 |
+
|
| 48 |
+
// Returns batch dimension of this tensor
|
| 49 |
+
int64_t level() const { return level_; }
|
| 50 |
+
|
| 51 |
+
// BatchedTensorImpl wraps a Tensor
|
| 52 |
+
const Tensor& value() const { return value_; }
|
| 53 |
+
|
| 54 |
+
// Given a public dimension index, return the dimension index in the underlying
  // value() tensor.
  // For example, if we have
  //   bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
  //   bt.actualDim(0) -> 1
  //   bt.actualDim(1) -> 2
  //   bt.actualDim(2) -> 3
  //   bt.actualDim(3) -> Error
  int64_t actualDim(int64_t dim, bool wrap_dim = true) const;

  IntArrayRef sizes_custom() const override;
  SymIntArrayRef sym_sizes_custom() const override;
  int64_t size_custom(int64_t d) const override;
  c10::SymInt sym_size_custom(int64_t d) const override;
  // We have to override this because we opted into CustomStrides.
  IntArrayRef strides_custom() const override;
  SymIntArrayRef sym_strides_custom() const override;
  // Override a bunch of methods inherited from TensorImpl to return error messages.
  bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
  void set_size(int64_t dim, int64_t new_size) override;
  void set_stride(int64_t dim, int64_t new_stride) override;
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      const c10::VariableVersion& version_counter,
      bool allow_tensor_metadata_change) const override;
  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
      c10::VariableVersion&& version_counter,
      bool allow_tensor_metadata_change) const override;
  void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
#ifdef DEBUG
  bool has_storage() const override;
#endif

  void refreshTensorMetadata();

  // Used in torchdim. torchdim uses non-lexical BatchedTensors; the way it
  // accomplishes this is a hack where it is able to modify the levels of
  // a BatchedTensor to match the level of the current vmap transform.
  void _unsafe_set_level(int64_t level) {
    level_ = level;
  }

  // Used in the batching rules for in-place view operations that can change
  // the index of the bdim (think squeeze_, unsqueeze_).
  void unsafe_set_bdim(int64_t bdim) {
    // NB: you MUST call refreshTensorMetadata after doing this.
    bdim_ = bdim;
  }
 private:
  // see NOTE: [BatchedTensorImpl levels invariant]
  void checkInvariants() const;
  const char* tensorimpl_type_name() const override;

  Tensor value_;

  int64_t level_;
  int64_t bdim_;
};

// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
// BatchedTensorImpl.
inline bool isBatchedTensor(const Tensor& tensor) {
  return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) ||
      tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor);
}

// It is unsafe to call this on a Tensor that is not backed by a
// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
  return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
}

inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
  if (!isBatchedTensor(tensor)) {
    return nullptr;
  }
  return unsafeGetBatchedImpl(tensor);
}

// Returns a bitset. If bit i is set, then dim i is a batch dim.
inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
  std::bitset<kVmapMaxTensorDims> is_bdim;
  is_bdim.set(dim);
  return is_bdim;
}

// Creates a bitset for the given vmap level.
inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
  std::bitset<kVmapNumLevels> result;
  result.set(level);
  return result;
}

// Use this to construct a BatchedTensor from a regular Tensor.
TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);

// Adds a batch dim to `tensor`, returning a BatchedTensor.
TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);

// Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
// any wrapper Tensor subclass). This is because there are methods on Tensor
// that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
// TODO: should probably contain more (or all?) backend keys.
constexpr DispatchKeySet kKeysToPropagateToWrapper({
  DispatchKey::Negative,
  DispatchKey::Conjugate,
  DispatchKey::XLA,
  DispatchKey::CUDA,
  DispatchKey::CPU,
});

inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
  auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  // Intersect with the requested key set rather than silently ignoring `to_propagate`.
  return key_set & to_propagate;
}

} // namespace at::functorch
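The `actualDim` mapping documented above is simple bookkeeping: with one batch dim at physical index `bdim`, every logical dim at or past `bdim` shifts right by one. Below is a minimal standalone sketch of that rule, matching the `bt.actualDim` example in the header comment; the name `actualDimSketch` and the free-function form are hypothetical, not the header's implementation.

// actual_dim_sketch.cpp -- standalone; does not depend on the torch headers.
#include <cstdint>
#include <iostream>
#include <stdexcept>

// With a single batch dim at physical index `bdim`, a logical dim maps to the
// physical dim that skips over `bdim`. `wrap_dim` mirrors the header's default
// of wrapping negative dims against the logical rank.
int64_t actualDimSketch(int64_t dim, int64_t bdim, int64_t logical_rank,
                        bool wrap_dim = true) {
  if (wrap_dim && dim < 0) dim += logical_rank;
  if (dim < 0 || dim >= logical_rank) throw std::out_of_range("dim");  // "-> Error"
  return dim >= bdim ? dim + 1 : dim;
}

int main() {
  // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0): logical rank is 3.
  for (int64_t d = 0; d < 3; ++d) {
    std::cout << "actualDim(" << d << ") -> "
              << actualDimSketch(d, /*bdim=*/0, /*logical_rank=*/3) << "\n";
  }  // prints 1, 2, 3, as in the header comment; dim 3 would throw.
}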
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h
ADDED
@@ -0,0 +1,126 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/Tensor.h>
#include <ATen/VmapGeneratedPlumbing.h>

// This file contains template metaprogramming things that are used for our
// batching rules.
//
// See NOTE: [vmap plumbing] for more details on why this is necessary.
// The plumbing has a bunch of metaprogramming hacks for determining the signature
// of a batching rule from the signature of the operator, many of which use the
// helper functions in this file.

namespace at::functorch {

// Metaprogramming things
template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
template <typename T> class debug_t;

// tail operation
template<class TypeList>
struct tail final {
  static_assert(c10::guts::false_t<TypeList>::value,
                "In typelist::tail<T>, the T argument must be typelist<...>.");
};
template<class Head, class... Tail>
struct tail<typelist<Head, Tail...>> final {
  using type = typelist<Tail...>;
};
template<class TypeList> using tail_t = typename tail<TypeList>::type;

template <class First, class Second, class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
  using type = Next;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<optional<Tensor>, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const optional<Tensor>&, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<optional<Tensor>&, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class Next, class Tail>
struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::vector<Tensor>, optional<int64_t>, Next, Tail> {
  using type = Tail;
};
template <class TypeList> struct RemoveBatchDimAfterTensor {
  using first = head_t<TypeList>;
  using next = tail_t<TypeList>;
  using second = head_t<next>;
  using tail = tail_t<next>;

  using type = concat_t<
    typelist<first>,
    typename RemoveBatchDimAfterTensor<
      typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
    >::type
  >;
};
template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
  using type = typelist<Type>;
};
template <> struct RemoveBatchDimAfterTensor<typelist<>> {
  using type = typelist<>;
};
template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;

template <typename T> struct UnpackSingleItemTuple {
  using type = T;
};
template <typename T> struct UnpackSingleItemTuple<std::tuple<T>> {
  using type = T;
};
template <typename T> using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;

template <typename Return, typename TupleArgs> struct BuildFunctionHelper;
template <typename Return, typename... Args> struct BuildFunctionHelper<Return, std::tuple<Args...>> {
  using type = Return(Args...);
};
template <typename Return, typename TL>
struct BuildFunction {
  using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
};
template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;

template <typename batch_rule_t> struct ToOperatorType {
  using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
  using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;

  using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
  using operator_return_type =
    unpack_single_item_tuple_t<
      c10::guts::typelist::to_tuple_t<
        remove_batch_dim_after_tensor_t<
          c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;

  using type = build_function_t<operator_return_type, operator_parameter_types>;
};
template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;

} // namespace at::functorch
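The core trick in this header is `RemoveBatchDimAfterTensor`: it walks the batching rule's parameter typelist and drops the `optional<int64_t>` that follows each Tensor-like parameter, recovering the operator's signature. A self-contained sketch of the same recursion, using `std::tuple` in place of `c10::guts::typelist` (every name here is a stand-in, not the real c10 machinery):

// typelist_demo.cpp -- standalone sketch of the RemoveBatchDimAfterTensor idea.
#include <cstdint>
#include <optional>
#include <tuple>
#include <type_traits>

struct Tensor {};  // stand-in for at::Tensor

template <typename... Ts> struct RemoveBatchDimAfterTensor;
template <> struct RemoveBatchDimAfterTensor<> { using type = std::tuple<>; };
template <typename T> struct RemoveBatchDimAfterTensor<T> { using type = std::tuple<T>; };

// Helper: prepend Head to an already-processed tuple.
template <typename Head, typename Tuple> struct Prepend;
template <typename Head, typename... Ts>
struct Prepend<Head, std::tuple<Ts...>> { using type = std::tuple<Head, Ts...>; };

template <typename First, typename Second, typename... Rest>
struct RemoveBatchDimAfterTensor<First, Second, Rest...> {
  // If (Tensor, optional<int64_t>) appears, drop the optional and continue.
  static constexpr bool drop =
      std::is_same_v<First, Tensor> && std::is_same_v<Second, std::optional<int64_t>>;
  using recurse = std::conditional_t<
      drop,
      RemoveBatchDimAfterTensor<Rest...>,
      RemoveBatchDimAfterTensor<Second, Rest...>>;
  using type = typename Prepend<First, typename recurse::type>::type;
};

// A batching-rule signature (Tensor, optional<int64_t>, double) corresponds to
// an operator signature (Tensor, double):
static_assert(std::is_same_v<
    RemoveBatchDimAfterTensor<Tensor, std::optional<int64_t>, double>::type,
    std::tuple<Tensor, double>>);

int main() {}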
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h
ADDED
@@ -0,0 +1,124 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/functorch/Macros.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Optional.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>
#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/FunctionalizeInterpreter.h>

// Forward declared
namespace c10 { struct AutogradMetaInterface; }

namespace at::functorch {

// This file contains the implementation of functorch's interpreter stack.
// See NOTE: [functorch interpreter stack] first before reading on.
//
// NB: the functorch interpreter stack is also referred to as:
// - the "dynamic layer stack" -- an older name for "interpreter" was
//   "dynamic layer".
// - the "functorch mode stack". You can think of each functorch transform as a
//   "mode" (in the same sense as torch_dispatch mode or torch_function mode),
//   and functorch as an implementation of a "mode stack" where the modes
//   may be arbitrarily composed.

// DynamicLayer is basically the same thing as an Interpreter.
// It represents a functorch transform and it holds an Interpreter,
// which contains metadata related to the transform and instructions on
// how to perform the transform.
//
// TODO: we can excise DynamicLayer in favor of Interpreter,
// but I am going to leave it for now as a compatibility shim to avoid
// needing to refactor a lot of callsites...
struct TORCH_API DynamicLayer {
  explicit DynamicLayer(
      TransformType transform_type,
      int64_t layerId,
      optional<c10::SymInt> batchSize = nullopt,
      optional<RandomnessType> randomness = nullopt,
      optional<bool> prev_grad_mode = nullopt,
      optional<bool> pre_fwd_grad_mode = nullopt,
      optional<bool> functionalize_add_back_views = nullopt);

  TransformType key() const;
  int64_t layerId() const;

  const Interpreter& interpreter() const { return interpreter_; }
  Interpreter& interpreter() { return interpreter_; }

  // Only valid for vmap
  c10::SymInt batchSize() const;
  RandomnessType randomness() const;

 private:
  Interpreter interpreter_;
};

TORCH_API int64_t initAndPushDynamicLayer(
    TransformType transform_type,
    optional<c10::SymInt> batch_size = nullopt,
    optional<RandomnessType> randomness = nullopt,
    optional<bool> prev_grad_mode = nullopt,
    optional<bool> prev_fwd_grad_mode = nullopt,
    optional<bool> functionalize_add_back_views = nullopt);
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
TORCH_API std::optional<DynamicLayer> maybeCurrentDynamicLayer();
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);

// NOTE: [Life handles and lexically scoped transforms]
// functorch transforms are lexically scoped.
// Given a level, we store a "life handle" that is a boolean that tells us if the
// transform with that level is active or not.
//
// functorch's TensorWrapper (for grad transforms) stores a life handle.
// If a TensorWrapper escapes from the scope of the transform, then somehow
// it must know it escaped; it can tell by querying the life handle.
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);

// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);

// Given the index of an immutable input and the schema, returns the index of
// the output (if any) that aliases it and should therefore remain unwrapped.
TORCH_API std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);

TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);

// Pretty printers
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);

// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
// is disabled by default. The following two APIs are for enabling it.
// These are not user-facing APIs. We can delete this in the future, but
// it is useful for debugging when something goes wrong with the
// autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
// because it leads to loud errors if something is incorrect.
TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
TORCH_API bool getSingleLevelAutogradFunctionAllowed();

// While a functorch grad transform is active, Tensor.requires_grad_() gets
// disabled. These two functions are the mechanism for controlling that.
TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
TORCH_API bool getInplaceRequiresGradAllowed();

TORCH_API DynamicLayer popDynamicLayer();
TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);

} // namespace at::functorch
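NOTE: [Life handles and lexically scoped transforms] above describes a small ownership pattern: a wrapper keeps a `shared_ptr<bool>` whose value flips when the transform leaves scope, so the wrapper can cheaply detect that it escaped. A standalone sketch of that pattern with hypothetical names (`WrapperSketch` is not functorch's TensorWrapper):

// life_handle_demo.cpp -- standalone sketch of the "life handle" pattern.
#include <iostream>
#include <memory>

// The wrapper holds a cheap handle to the transform that created it, so it can
// later ask "is my transform still active?" without keeping the transform alive.
struct WrapperSketch {
  std::shared_ptr<bool> life_handle;
  bool escaped() const { return !*life_handle; }
};

int main() {
  auto handle = std::make_shared<bool>(true);   // transform enters scope
  WrapperSketch w{handle};
  std::cout << "inside transform, escaped? " << w.escaped() << "\n";  // prints 0
  *handle = false;                              // transform exits scope
  std::cout << "after transform, escaped? " << w.escaped() << "\n";   // prints 1
}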
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h
ADDED
@@ -0,0 +1,22 @@
#pragma once
#include <ATen/functorch/Interpreter.h>

namespace at::functorch {

// This is the interpreter that handles the functionalize() transform.
// See NOTE: [functorch interpreter stack] for more details.

struct FunctionalizeInterpreterPtr {
  explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
  TransformType key() const { return base_->key(); }
  int64_t level() const { return base_->level(); }
  void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  bool functionalizeAddBackViews() const {
    return std::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
  }
 private:
  const Interpreter* base_;
};

} // namespace at::functorch
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h
ADDED
@@ -0,0 +1,209 @@
#pragma once

#include <ATen/functorch/Macros.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Optional.h>
#include <bitset>
#include <utility>
#include <variant>

namespace at::functorch {

// NOTE: [functorch interpreter stack]
//
// functorch's dispatching system uses a stack of interpreters.
// Historically we've referred to this as the "DynamicLayerStack".
//
// An interpreter is something that reads in the code it is passed
// and then executes it. We have a different interpreter per transform:
// the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
// and executing the batched version of them (the batching rule for aten::mv).
//
// Concretely, each interpreter is responsible for two things:
//
// 1) process(ophandle, stack)
// Given an operator handle and a stack of arguments, the interpreter is
// responsible for figuring out how to execute the operation under the semantics
// of the interpreter. For the VmapInterpreter, this means figuring out how to call
// the batching rule.
//
// The batching rules are stored as kernels on the FuncTorchBatched key, so the way
// the VmapInterpreter calls the batching rule is roughly: (A) exclude all
// dispatch keys aside from the Batched key, (B) redispatch so we get to the
// Batched key.
//
// 2) sendToNextInterpreter(ophandle, stack)
// The VmapInterpreter, when it sees aten::mv, will process it into a call to
// aten::mm. It then needs to send the call to aten::mm to the next interpreter
// in the interpreter stack.
//
// The VmapInterpreter just does this via a call to ophandle.callBoxed(stack),
// and most Interpreters will implement it this way.

enum class RandomnessType {
  Error,      // always errors when calling a random function
  Same,       // randomness appears the same across batches
  Different,  // randomness appears different across batches
  END
};

enum class TransformType {
  Torch,  // Unused
  Vmap,
  Grad,  // reverse-mode AD, aka vjp
  Jvp,  // forward-mode AD
  Functionalize,
};

std::ostream& operator<<(std::ostream& os, const TransformType& t);

// NOTE: [Interpreter "subclassing" design]
//
// How are the various Interpreters for different transforms (vmap, grad, ...)
// implemented?
//
// Accessing interpreters is in the hot path of functorch, so we have a constraint
// that this code must be as fast as possible.
//
// As a result, we stay away from virtual methods, and this causes our code
// to look a little funny.
//
// `Interpreter` is the struct for Interpreters. It holds ALL of the
// relevant information (what type of interpreter it is and the metadata).
// Metadata for each interpreter is represented as a union (std::variant)
// of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
//
// Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
// if you want to access the metadata fields (like batchSize and randomness).
//
// Each type of interpreter (e.g. Vmap) has a convenience struct
// (e.g. VmapInterpreterPtr) associated with it.
//
// Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
// and then one can access methods on VmapInterpreterPtr like so:
// >>> VmapInterpreterPtr(&interpreter).batchSize()
//
// Finally, Interpreter::process switches on the type of the interpreter
// and calls one of {Transform}Interpreter::processImpl under the hood.
// Same for Interpreter::sendToNextInterpreter :)

struct VmapInterpreterMeta {
  explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
    batchSize_(std::move(batchSize)), randomness_(randomness) {}
  c10::SymInt batchSize_;
  RandomnessType randomness_;
};

struct GradInterpreterMeta {
  explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
  bool prevGradMode_;
};

struct JvpInterpreterMeta {
  explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
  bool prevFwdGradMode_;
};

struct FunctionalizeInterpreterMeta {
  explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
    functionalizeAddBackViews_(functionalizeAddBackViews) {}
  bool functionalizeAddBackViews_;
};

typedef std::variant<
  int64_t,
  GradInterpreterMeta,
  JvpInterpreterMeta,
  VmapInterpreterMeta,
  FunctionalizeInterpreterMeta
> InterpreterMeta;

struct Interpreter {
  // factory functions
  static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
    return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
  }
  static Interpreter Grad(int64_t level, bool prevGradMode) {
    return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
  }
  static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
    return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
  }
  static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
    return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
  }

  // methods
  TransformType key() const { return type_; }
  int64_t level() const { return level_; }
  const InterpreterMeta& meta() const { return meta_; }

  void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);

  void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
    TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
    savedLocalDispatchKeySet_ = keyset;
  }
  void clearSavedLocalDispatchKeySet() {
    TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
    savedLocalDispatchKeySet_ = c10::nullopt;
  }
  c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
    TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
    return *savedLocalDispatchKeySet_;
  }

  // An Interpreter is alive if we are currently inside the ongoing transform
  // for the interpreter. For example, in vmap(f)(x): inside of f, the vmap's
  // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
  bool is_alive() const {
    return *is_alive_;
  }
  const std::shared_ptr<bool>& is_alive_ptr() const {
    return is_alive_;
  }
  void set_is_alive(bool alive) {
    *is_alive_ = alive;
  }

  // Please don't use this
  explicit Interpreter() = default;

 private:
  explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
    type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}

  // fields
  TransformType type_{};
  int64_t level_{};
  optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
  std::shared_ptr<bool> is_alive_;
  InterpreterMeta meta_;
};

// Applies the following for-loop:
// for i in range(begin, end):
//   args[i] = func(args[i])
void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
    std::function<Tensor(const Tensor&)> func);

// Applies the following for-loop, treating use_flag_relative as a bitset
// indexed relative to `begin`:
// for i in range(begin, end):
//   args[i] = func(args[i], use_flag_relative[i - begin])
void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
    const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func);

std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);

DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);

void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);

void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);

} // namespace at::functorch
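NOTE: [Interpreter "subclassing" design] avoids virtual dispatch by pairing a type tag with a `std::variant` of per-transform metadata, plus thin "Ptr" views for typed access. A self-contained sketch of that layout with hypothetical names and only two transforms; the real Interpreter additionally carries a level and an is-alive flag:

// interpreter_variant_demo.cpp -- standalone sketch of the variant-based design.
#include <cstdint>
#include <iostream>
#include <variant>

enum class TransformType { Vmap, Grad };

struct VmapMeta { int64_t batch_size; };
struct GradMeta { bool prev_grad_mode; };
using Meta = std::variant<VmapMeta, GradMeta>;

struct InterpreterSketch {
  TransformType type;
  Meta meta;

  // No virtual dispatch: process() switches on the tag and calls the
  // transform-specific code, mirroring how Interpreter::process is described.
  void process() const {
    switch (type) {
      case TransformType::Vmap:
        std::cout << "vmap, batch_size=" << std::get<VmapMeta>(meta).batch_size << "\n";
        break;
      case TransformType::Grad:
        std::cout << "grad, prev_grad_mode=" << std::get<GradMeta>(meta).prev_grad_mode << "\n";
        break;
    }
  }
};

// Convenience "Ptr" view, analogous to VmapInterpreterPtr: a non-owning
// accessor that exposes the typed metadata of one transform kind.
struct VmapPtrSketch {
  const InterpreterSketch* base;
  int64_t batchSize() const { return std::get<VmapMeta>(base->meta).batch_size; }
};

int main() {
  InterpreterSketch vmap{TransformType::Vmap, VmapMeta{8}};
  vmap.process();                                     // vmap, batch_size=8
  std::cout << VmapPtrSketch{&vmap}.batchSize() << "\n";  // 8
}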
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h
ADDED
@@ -0,0 +1,187 @@
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <ATen/functorch/Macros.h>
#include <ATen/functorch/BatchedTensorImpl.h>

namespace at::functorch {

// This file contains the legacy (now-deprecated) batching rule API.
// Please try to use the new-style batching rule API (see writing_batch_rules.md).

// This file contains abstractions used for transforming *logical* vmap arguments
// into *physical* arguments. (Keep reading for definitions of these terms.)

// NOTE: [Logical vs physical args]
// Consider the following vmap.
//   vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
//   BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0, but the physical dim is dim 1
// (the first non-batch dimension).

// Forward declared; see NOTE: [What is a VmapPhysicalView?]
struct VmapPhysicalView;

// Most PyTorch operators take 4 or fewer inputs.
constexpr int64_t kVmapTransformStaticInputSize = 4;
using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;

// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
constexpr int64_t kVmapStaticDimVecSize = 8;
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;

// NOTE: [What is a VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical result back to a logical
// argument.

// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
struct TORCH_API MultiBatchVmapTransform {
  static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
  static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};

// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
//   If a tensor does not have a batch dim for a vmap level, then it receives
//   a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
//   size-1 dimensions in between the batch dimensions and the non-batch dimensions
//   so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors
// of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyway to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
  static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};

// Forward declared; if you're reading this file head to toe, don't worry about
// it yet.
struct VmapPhysicalToLogicalMap;

// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or to convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented.)
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor`, and the rightmost
// set bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
// For example, given:
//   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// The rightmost set bit of `levels` is 3, indicating that the number of nested
// vmaps is less than or equal to 3.
//   bitset: 010100
//              ^
//              |
//   levels: 012345
struct TORCH_API VmapPhysicalView {
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
  }

  Tensor& tensor() { return tensor_; }
  const Tensor& tensor() const { return tensor_; }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the size of `levels` tells us that the first two dimensions
  // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
  // a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor).
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by pre-pending the batch
  // sizes to the logical shape.
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
  SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const;

  int64_t numBatchDims() const;

 private:
  int64_t numLogicalDims() const;

  std::bitset<kVmapNumLevels> levels_;
  Tensor tensor_;
};

// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do the
// mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
  VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels): levels_(levels) {}

  // Maps a physical tensor to a new logical tensor (BatchedTensor).
  // Assumes that all of the "batch dimensions" are at the front
  // of the physical tensor. For example, given:
  // - x = rank-4 Tensor with size 2, 3, 5, 7
  // - levels = (2, 4)
  // Returns:
  // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
  Tensor apply(const Tensor& physical_tensor) const;

  // Given a vector of physical tensors,
  // 1. maps each tensor to a new logical tensor. Assumes that all of the
  //    "batch dimensions" are at the front of the physical tensors.
  // 2. stores the new logical tensors back into the passed-in vector. This is
  //    to avoid additional dynamic allocations.
  void applyInplace(std::vector<Tensor>& physical_tensors) const;

  std::bitset<kVmapNumLevels> levels_;
};

} // namespace at::functorch
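The logical-to-physical arithmetic described in NOTE: [What is a VmapPhysicalView?] reduces to counting the set bits in `levels` and offsetting (or pre-pending) by that count. A standalone sketch of `getPhysicalDim` and `getPhysicalShape` under that reading, with hypothetical names; batch sizes are passed explicitly since there is no real tensor here:

// physical_view_demo.cpp -- standalone sketch of logical->physical mapping.
#include <bitset>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kNumLevels = 32;

struct PhysicalViewSketch {
  std::bitset<kNumLevels> levels;  // one set bit per batch dim at the front

  int64_t numBatchDims() const { return static_cast<int64_t>(levels.count()); }

  // A logical dim `n` is physical dim `n + numBatchDims()` because all batch
  // dims sit at the front of the physical tensor.
  int64_t getPhysicalDim(int64_t logical_dim) const {
    return logical_dim + numBatchDims();
  }

  // Pre-pend the batch sizes to a logical shape to get the physical shape.
  std::vector<int64_t> getPhysicalShape(std::vector<int64_t> logical_shape,
                                        const std::vector<int64_t>& batch_sizes) const {
    logical_shape.insert(logical_shape.begin(), batch_sizes.begin(), batch_sizes.end());
    return logical_shape;
  }
};

int main() {
  PhysicalViewSketch view{std::bitset<kNumLevels>{}.set(1).set(3)};  // levels {1, 3}
  // tensor=ones(2, 3, 4, 5): two batch dims, so logical dims {0, 1} -> {2, 3}.
  std::cout << view.getPhysicalDim(0) << " " << view.getPhysicalDim(1) << "\n";
  for (auto s : view.getPhysicalShape({4, 5}, {2, 3})) std::cout << s << " ";
  std::cout << "\n";  // 2 3 4 5
}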
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/Macros.h
ADDED
@@ -0,0 +1,3 @@
#pragma once

#define SINGLE_ARG(...) __VA_ARGS__
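`SINGLE_ARG` exists because the preprocessor splits macro arguments at top-level commas, so a type like `std::map<int, int>` cannot be passed to a macro directly. A short standalone demonstration; the `DECLARE` macro is hypothetical, invented only for this example:

// single_arg_demo.cpp -- why a SINGLE_ARG-style wrapper is useful.
#include <iostream>
#include <map>

#define SINGLE_ARG(...) __VA_ARGS__

// Declares a variable of the given type. DECLARE(std::map<int, int>, m) would
// fail to compile: the comma inside the template argument list splits the type
// into two macro arguments.
#define DECLARE(type, name) type name

int main() {
  DECLARE(SINGLE_ARG(std::map<int, int>), m);  // the wrapper re-joins the type
  m[1] = 2;
  std::cout << m[1] << "\n";  // prints 2
}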