xiaoanyu123 committed on
Commit 950f7f1 · verified · 1 Parent(s): 438ceb1

Add files using upload-large-folder tool

Files changed (20)
  1. pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/__init__.py +0 -0
  2. pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/__pycache__/test_algebraic_connectivity.cpython-310.pyc +0 -0
  3. pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_graphmatrix.py +276 -0
  4. pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_laplacian.py +336 -0
  5. pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_modularity.py +87 -0
  6. pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_spectrum.py +71 -0
  7. pythonProject/.venv/Lib/site-packages/numpy/ctypeslib.py +602 -0
  8. pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h +103 -0
  9. pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h +61 -0
  10. pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h +84 -0
  11. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h +38 -0
  12. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h +475 -0
  13. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h +81 -0
  14. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h +169 -0
  15. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h +126 -0
  16. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h +124 -0
  17. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h +22 -0
  18. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h +209 -0
  19. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h +187 -0
  20. pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/Macros.h +3 -0
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/__init__.py ADDED
File without changes
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/__pycache__/test_algebraic_connectivity.cpython-310.pyc ADDED
Binary file (12.2 kB).
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_graphmatrix.py ADDED
@@ -0,0 +1,276 @@
1
+ import pytest
2
+
3
+ np = pytest.importorskip("numpy")
4
+ pytest.importorskip("scipy")
5
+
6
+ import networkx as nx
7
+ from networkx.exception import NetworkXError
8
+ from networkx.generators.degree_seq import havel_hakimi_graph
9
+
10
+
11
+ def test_incidence_matrix_simple():
12
+ deg = [3, 2, 2, 1, 0]
13
+ G = havel_hakimi_graph(deg)
14
+ deg = [(1, 0), (1, 0), (1, 0), (2, 0), (1, 0), (2, 1), (0, 1), (0, 1)]
15
+ MG = nx.random_clustered_graph(deg, seed=42)
16
+
17
+ I = nx.incidence_matrix(G, dtype=int).todense()
18
+ # fmt: off
19
+ expected = np.array(
20
+ [[1, 1, 1, 0],
21
+ [0, 1, 0, 1],
22
+ [1, 0, 0, 1],
23
+ [0, 0, 1, 0],
24
+ [0, 0, 0, 0]]
25
+ )
26
+ # fmt: on
27
+ np.testing.assert_equal(I, expected)
28
+
29
+ I = nx.incidence_matrix(MG, dtype=int).todense()
30
+ # fmt: off
31
+ expected = np.array(
32
+ [[1, 0, 0, 0, 0, 0, 0],
33
+ [1, 0, 0, 0, 0, 0, 0],
34
+ [0, 1, 0, 0, 0, 0, 0],
35
+ [0, 0, 0, 0, 0, 0, 0],
36
+ [0, 1, 0, 0, 0, 0, 0],
37
+ [0, 0, 0, 0, 1, 1, 0],
38
+ [0, 0, 0, 0, 0, 1, 1],
39
+ [0, 0, 0, 0, 1, 0, 1]]
40
+ )
41
+ # fmt: on
42
+ np.testing.assert_equal(I, expected)
43
+
44
+ with pytest.raises(NetworkXError):
45
+ nx.incidence_matrix(G, nodelist=[0, 1])
46
+
47
+
48
+ class TestGraphMatrix:
49
+ @classmethod
50
+ def setup_class(cls):
51
+ deg = [3, 2, 2, 1, 0]
52
+ cls.G = havel_hakimi_graph(deg)
53
+ # fmt: off
54
+ cls.OI = np.array(
55
+ [[-1, -1, -1, 0],
56
+ [1, 0, 0, -1],
57
+ [0, 1, 0, 1],
58
+ [0, 0, 1, 0],
59
+ [0, 0, 0, 0]]
60
+ )
61
+ cls.A = np.array(
62
+ [[0, 1, 1, 1, 0],
63
+ [1, 0, 1, 0, 0],
64
+ [1, 1, 0, 0, 0],
65
+ [1, 0, 0, 0, 0],
66
+ [0, 0, 0, 0, 0]]
67
+ )
68
+ # fmt: on
69
+ cls.WG = havel_hakimi_graph(deg)
70
+ cls.WG.add_edges_from(
71
+ (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges()
72
+ )
73
+ # fmt: off
74
+ cls.WA = np.array(
75
+ [[0, 0.5, 0.5, 0.5, 0],
76
+ [0.5, 0, 0.5, 0, 0],
77
+ [0.5, 0.5, 0, 0, 0],
78
+ [0.5, 0, 0, 0, 0],
79
+ [0, 0, 0, 0, 0]]
80
+ )
81
+ # fmt: on
82
+ cls.MG = nx.MultiGraph(cls.G)
83
+ cls.MG2 = cls.MG.copy()
84
+ cls.MG2.add_edge(0, 1)
85
+ # fmt: off
86
+ cls.MG2A = np.array(
87
+ [[0, 2, 1, 1, 0],
88
+ [2, 0, 1, 0, 0],
89
+ [1, 1, 0, 0, 0],
90
+ [1, 0, 0, 0, 0],
91
+ [0, 0, 0, 0, 0]]
92
+ )
93
+ cls.MGOI = np.array(
94
+ [[-1, -1, -1, -1, 0],
95
+ [1, 1, 0, 0, -1],
96
+ [0, 0, 1, 0, 1],
97
+ [0, 0, 0, 1, 0],
98
+ [0, 0, 0, 0, 0]]
99
+ )
100
+ # fmt: on
101
+ cls.no_edges_G = nx.Graph([(1, 2), (3, 2, {"weight": 8})])
102
+ cls.no_edges_A = np.array([[0, 0], [0, 0]])
103
+
104
+ def test_incidence_matrix(self):
105
+ "Conversion to incidence matrix"
106
+ I = nx.incidence_matrix(
107
+ self.G,
108
+ nodelist=sorted(self.G),
109
+ edgelist=sorted(self.G.edges()),
110
+ oriented=True,
111
+ dtype=int,
112
+ ).todense()
113
+ np.testing.assert_equal(I, self.OI)
114
+
115
+ I = nx.incidence_matrix(
116
+ self.G,
117
+ nodelist=sorted(self.G),
118
+ edgelist=sorted(self.G.edges()),
119
+ oriented=False,
120
+ dtype=int,
121
+ ).todense()
122
+ np.testing.assert_equal(I, np.abs(self.OI))
123
+
124
+ I = nx.incidence_matrix(
125
+ self.MG,
126
+ nodelist=sorted(self.MG),
127
+ edgelist=sorted(self.MG.edges()),
128
+ oriented=True,
129
+ dtype=int,
130
+ ).todense()
131
+ np.testing.assert_equal(I, self.OI)
132
+
133
+ I = nx.incidence_matrix(
134
+ self.MG,
135
+ nodelist=sorted(self.MG),
136
+ edgelist=sorted(self.MG.edges()),
137
+ oriented=False,
138
+ dtype=int,
139
+ ).todense()
140
+ np.testing.assert_equal(I, np.abs(self.OI))
141
+
142
+ I = nx.incidence_matrix(
143
+ self.MG2,
144
+ nodelist=sorted(self.MG2),
145
+ edgelist=sorted(self.MG2.edges()),
146
+ oriented=True,
147
+ dtype=int,
148
+ ).todense()
149
+ np.testing.assert_equal(I, self.MGOI)
150
+
151
+ I = nx.incidence_matrix(
152
+ self.MG2,
153
+ nodelist=sorted(self.MG),
154
+ edgelist=sorted(self.MG2.edges()),
155
+ oriented=False,
156
+ dtype=int,
157
+ ).todense()
158
+ np.testing.assert_equal(I, np.abs(self.MGOI))
159
+
160
+ I = nx.incidence_matrix(self.G, dtype=np.uint8)
161
+ assert I.dtype == np.uint8
162
+
163
+ def test_weighted_incidence_matrix(self):
164
+ I = nx.incidence_matrix(
165
+ self.WG,
166
+ nodelist=sorted(self.WG),
167
+ edgelist=sorted(self.WG.edges()),
168
+ oriented=True,
169
+ dtype=int,
170
+ ).todense()
171
+ np.testing.assert_equal(I, self.OI)
172
+
173
+ I = nx.incidence_matrix(
174
+ self.WG,
175
+ nodelist=sorted(self.WG),
176
+ edgelist=sorted(self.WG.edges()),
177
+ oriented=False,
178
+ dtype=int,
179
+ ).todense()
180
+ np.testing.assert_equal(I, np.abs(self.OI))
181
+
182
+ # np.testing.assert_equal(nx.incidence_matrix(self.WG,oriented=True,
183
+ # weight='weight').todense(),0.5*self.OI)
184
+ # np.testing.assert_equal(nx.incidence_matrix(self.WG,weight='weight').todense(),
185
+ # np.abs(0.5*self.OI))
186
+ # np.testing.assert_equal(nx.incidence_matrix(self.WG,oriented=True,weight='other').todense(),
187
+ # 0.3*self.OI)
188
+
189
+ I = nx.incidence_matrix(
190
+ self.WG,
191
+ nodelist=sorted(self.WG),
192
+ edgelist=sorted(self.WG.edges()),
193
+ oriented=True,
194
+ weight="weight",
195
+ ).todense()
196
+ np.testing.assert_equal(I, 0.5 * self.OI)
197
+
198
+ I = nx.incidence_matrix(
199
+ self.WG,
200
+ nodelist=sorted(self.WG),
201
+ edgelist=sorted(self.WG.edges()),
202
+ oriented=False,
203
+ weight="weight",
204
+ ).todense()
205
+ np.testing.assert_equal(I, np.abs(0.5 * self.OI))
206
+
207
+ I = nx.incidence_matrix(
208
+ self.WG,
209
+ nodelist=sorted(self.WG),
210
+ edgelist=sorted(self.WG.edges()),
211
+ oriented=True,
212
+ weight="other",
213
+ ).todense()
214
+ np.testing.assert_equal(I, 0.3 * self.OI)
215
+
216
+ # WMG=nx.MultiGraph(self.WG)
217
+ # WMG.add_edge(0,1,weight=0.5,other=0.3)
218
+ # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='weight').todense(),
219
+ # np.abs(0.5*self.MGOI))
220
+ # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='weight',oriented=True).todense(),
221
+ # 0.5*self.MGOI)
222
+ # np.testing.assert_equal(nx.incidence_matrix(WMG,weight='other',oriented=True).todense(),
223
+ # 0.3*self.MGOI)
224
+
225
+ WMG = nx.MultiGraph(self.WG)
226
+ WMG.add_edge(0, 1, weight=0.5, other=0.3)
227
+
228
+ I = nx.incidence_matrix(
229
+ WMG,
230
+ nodelist=sorted(WMG),
231
+ edgelist=sorted(WMG.edges(keys=True)),
232
+ oriented=True,
233
+ weight="weight",
234
+ ).todense()
235
+ np.testing.assert_equal(I, 0.5 * self.MGOI)
236
+
237
+ I = nx.incidence_matrix(
238
+ WMG,
239
+ nodelist=sorted(WMG),
240
+ edgelist=sorted(WMG.edges(keys=True)),
241
+ oriented=False,
242
+ weight="weight",
243
+ ).todense()
244
+ np.testing.assert_equal(I, np.abs(0.5 * self.MGOI))
245
+
246
+ I = nx.incidence_matrix(
247
+ WMG,
248
+ nodelist=sorted(WMG),
249
+ edgelist=sorted(WMG.edges(keys=True)),
250
+ oriented=True,
251
+ weight="other",
252
+ ).todense()
253
+ np.testing.assert_equal(I, 0.3 * self.MGOI)
254
+
255
+ def test_adjacency_matrix(self):
256
+ "Conversion to adjacency matrix"
257
+ np.testing.assert_equal(nx.adjacency_matrix(self.G).todense(), self.A)
258
+ np.testing.assert_equal(nx.adjacency_matrix(self.MG).todense(), self.A)
259
+ np.testing.assert_equal(nx.adjacency_matrix(self.MG2).todense(), self.MG2A)
260
+ np.testing.assert_equal(
261
+ nx.adjacency_matrix(self.G, nodelist=[0, 1]).todense(), self.A[:2, :2]
262
+ )
263
+ np.testing.assert_equal(nx.adjacency_matrix(self.WG).todense(), self.WA)
264
+ np.testing.assert_equal(
265
+ nx.adjacency_matrix(self.WG, weight=None).todense(), self.A
266
+ )
267
+ np.testing.assert_equal(
268
+ nx.adjacency_matrix(self.MG2, weight=None).todense(), self.MG2A
269
+ )
270
+ np.testing.assert_equal(
271
+ nx.adjacency_matrix(self.WG, weight="other").todense(), 0.6 * self.WA
272
+ )
273
+ np.testing.assert_equal(
274
+ nx.adjacency_matrix(self.no_edges_G, nodelist=[1, 3]).todense(),
275
+ self.no_edges_A,
276
+ )
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_laplacian.py ADDED
@@ -0,0 +1,336 @@
1
+ import pytest
2
+
3
+ np = pytest.importorskip("numpy")
4
+ pytest.importorskip("scipy")
5
+
6
+ import networkx as nx
7
+ from networkx.generators.degree_seq import havel_hakimi_graph
8
+ from networkx.generators.expanders import margulis_gabber_galil_graph
9
+
10
+
11
+ class TestLaplacian:
12
+ @classmethod
13
+ def setup_class(cls):
14
+ deg = [3, 2, 2, 1, 0]
15
+ cls.G = havel_hakimi_graph(deg)
16
+ cls.WG = nx.Graph(
17
+ (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges()
18
+ )
19
+ cls.WG.add_node(4)
20
+ cls.MG = nx.MultiGraph(cls.G)
21
+
22
+ # Graph with selfloops
23
+ cls.Gsl = cls.G.copy()
24
+ for node in cls.Gsl.nodes():
25
+ cls.Gsl.add_edge(node, node)
26
+
27
+ # Graph used as an example in Sec. 4.1 of Langville and Meyer,
28
+ # "Google's PageRank and Beyond".
29
+ cls.DiG = nx.DiGraph()
30
+ cls.DiG.add_edges_from(
31
+ (
32
+ (1, 2),
33
+ (1, 3),
34
+ (3, 1),
35
+ (3, 2),
36
+ (3, 5),
37
+ (4, 5),
38
+ (4, 6),
39
+ (5, 4),
40
+ (5, 6),
41
+ (6, 4),
42
+ )
43
+ )
44
+ cls.DiMG = nx.MultiDiGraph(cls.DiG)
45
+ cls.DiWG = nx.DiGraph(
46
+ (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.DiG.edges()
47
+ )
48
+ cls.DiGsl = cls.DiG.copy()
49
+ for node in cls.DiGsl.nodes():
50
+ cls.DiGsl.add_edge(node, node)
51
+
52
+ def test_laplacian(self):
53
+ "Graph Laplacian"
54
+ # fmt: off
55
+ NL = np.array([[ 3, -1, -1, -1, 0],
56
+ [-1, 2, -1, 0, 0],
57
+ [-1, -1, 2, 0, 0],
58
+ [-1, 0, 0, 1, 0],
59
+ [ 0, 0, 0, 0, 0]])
60
+ # fmt: on
61
+ WL = 0.5 * NL
62
+ OL = 0.3 * NL
63
+ # fmt: off
64
+ DiNL = np.array([[ 2, -1, -1, 0, 0, 0],
65
+ [ 0, 0, 0, 0, 0, 0],
66
+ [-1, -1, 3, -1, 0, 0],
67
+ [ 0, 0, 0, 2, -1, -1],
68
+ [ 0, 0, 0, -1, 2, -1],
69
+ [ 0, 0, 0, 0, -1, 1]])
70
+ # fmt: on
71
+ DiWL = 0.5 * DiNL
72
+ DiOL = 0.3 * DiNL
73
+ np.testing.assert_equal(nx.laplacian_matrix(self.G).todense(), NL)
74
+ np.testing.assert_equal(nx.laplacian_matrix(self.MG).todense(), NL)
75
+ np.testing.assert_equal(
76
+ nx.laplacian_matrix(self.G, nodelist=[0, 1]).todense(),
77
+ np.array([[1, -1], [-1, 1]]),
78
+ )
79
+ np.testing.assert_equal(nx.laplacian_matrix(self.WG).todense(), WL)
80
+ np.testing.assert_equal(nx.laplacian_matrix(self.WG, weight=None).todense(), NL)
81
+ np.testing.assert_equal(
82
+ nx.laplacian_matrix(self.WG, weight="other").todense(), OL
83
+ )
84
+
85
+ np.testing.assert_equal(nx.laplacian_matrix(self.DiG).todense(), DiNL)
86
+ np.testing.assert_equal(nx.laplacian_matrix(self.DiMG).todense(), DiNL)
87
+ np.testing.assert_equal(
88
+ nx.laplacian_matrix(self.DiG, nodelist=[1, 2]).todense(),
89
+ np.array([[1, -1], [0, 0]]),
90
+ )
91
+ np.testing.assert_equal(nx.laplacian_matrix(self.DiWG).todense(), DiWL)
92
+ np.testing.assert_equal(
93
+ nx.laplacian_matrix(self.DiWG, weight=None).todense(), DiNL
94
+ )
95
+ np.testing.assert_equal(
96
+ nx.laplacian_matrix(self.DiWG, weight="other").todense(), DiOL
97
+ )
98
+
99
+ def test_normalized_laplacian(self):
100
+ "Generalized Graph Laplacian"
101
+ # fmt: off
102
+ G = np.array([[ 1. , -0.408, -0.408, -0.577, 0.],
103
+ [-0.408, 1. , -0.5 , 0. , 0.],
104
+ [-0.408, -0.5 , 1. , 0. , 0.],
105
+ [-0.577, 0. , 0. , 1. , 0.],
106
+ [ 0. , 0. , 0. , 0. , 0.]])
107
+ GL = np.array([[ 1. , -0.408, -0.408, -0.577, 0. ],
108
+ [-0.408, 1. , -0.5 , 0. , 0. ],
109
+ [-0.408, -0.5 , 1. , 0. , 0. ],
110
+ [-0.577, 0. , 0. , 1. , 0. ],
111
+ [ 0. , 0. , 0. , 0. , 0. ]])
112
+ Lsl = np.array([[ 0.75 , -0.2887, -0.2887, -0.3536, 0. ],
113
+ [-0.2887, 0.6667, -0.3333, 0. , 0. ],
114
+ [-0.2887, -0.3333, 0.6667, 0. , 0. ],
115
+ [-0.3536, 0. , 0. , 0.5 , 0. ],
116
+ [ 0. , 0. , 0. , 0. , 0. ]])
117
+
118
+ DiG = np.array([[ 1. , 0. , -0.4082, 0. , 0. , 0. ],
119
+ [ 0. , 0. , 0. , 0. , 0. , 0. ],
120
+ [-0.4082, 0. , 1. , 0. , -0.4082, 0. ],
121
+ [ 0. , 0. , 0. , 1. , -0.5 , -0.7071],
122
+ [ 0. , 0. , 0. , -0.5 , 1. , -0.7071],
123
+ [ 0. , 0. , 0. , -0.7071, 0. , 1. ]])
124
+ DiGL = np.array([[ 1. , 0. , -0.4082, 0. , 0. , 0. ],
125
+ [ 0. , 0. , 0. , 0. , 0. , 0. ],
126
+ [-0.4082, 0. , 1. , -0.4082, 0. , 0. ],
127
+ [ 0. , 0. , 0. , 1. , -0.5 , -0.7071],
128
+ [ 0. , 0. , 0. , -0.5 , 1. , -0.7071],
129
+ [ 0. , 0. , 0. , 0. , -0.7071, 1. ]])
130
+ DiLsl = np.array([[ 0.6667, -0.5774, -0.2887, 0. , 0. , 0. ],
131
+ [ 0. , 0. , 0. , 0. , 0. , 0. ],
132
+ [-0.2887, -0.5 , 0.75 , -0.2887, 0. , 0. ],
133
+ [ 0. , 0. , 0. , 0.6667, -0.3333, -0.4082],
134
+ [ 0. , 0. , 0. , -0.3333, 0.6667, -0.4082],
135
+ [ 0. , 0. , 0. , 0. , -0.4082, 0.5 ]])
136
+ # fmt: on
137
+
138
+ np.testing.assert_almost_equal(
139
+ nx.normalized_laplacian_matrix(self.G, nodelist=range(5)).todense(),
140
+ G,
141
+ decimal=3,
142
+ )
143
+ np.testing.assert_almost_equal(
144
+ nx.normalized_laplacian_matrix(self.G).todense(), GL, decimal=3
145
+ )
146
+ np.testing.assert_almost_equal(
147
+ nx.normalized_laplacian_matrix(self.MG).todense(), GL, decimal=3
148
+ )
149
+ np.testing.assert_almost_equal(
150
+ nx.normalized_laplacian_matrix(self.WG).todense(), GL, decimal=3
151
+ )
152
+ np.testing.assert_almost_equal(
153
+ nx.normalized_laplacian_matrix(self.WG, weight="other").todense(),
154
+ GL,
155
+ decimal=3,
156
+ )
157
+ np.testing.assert_almost_equal(
158
+ nx.normalized_laplacian_matrix(self.Gsl).todense(), Lsl, decimal=3
159
+ )
160
+
161
+ np.testing.assert_almost_equal(
162
+ nx.normalized_laplacian_matrix(
163
+ self.DiG,
164
+ nodelist=range(1, 1 + 6),
165
+ ).todense(),
166
+ DiG,
167
+ decimal=3,
168
+ )
169
+ np.testing.assert_almost_equal(
170
+ nx.normalized_laplacian_matrix(self.DiG).todense(), DiGL, decimal=3
171
+ )
172
+ np.testing.assert_almost_equal(
173
+ nx.normalized_laplacian_matrix(self.DiMG).todense(), DiGL, decimal=3
174
+ )
175
+ np.testing.assert_almost_equal(
176
+ nx.normalized_laplacian_matrix(self.DiWG).todense(), DiGL, decimal=3
177
+ )
178
+ np.testing.assert_almost_equal(
179
+ nx.normalized_laplacian_matrix(self.DiWG, weight="other").todense(),
180
+ DiGL,
181
+ decimal=3,
182
+ )
183
+ np.testing.assert_almost_equal(
184
+ nx.normalized_laplacian_matrix(self.DiGsl).todense(), DiLsl, decimal=3
185
+ )
186
+
187
+
188
+ def test_directed_laplacian():
189
+ "Directed Laplacian"
190
+ # Graph used as an example in Sec. 4.1 of Langville and Meyer,
191
+ # "Google's PageRank and Beyond". The graph contains dangling nodes, so
192
+ # the pagerank random walk is selected by directed_laplacian
193
+ G = nx.DiGraph()
194
+ G.add_edges_from(
195
+ (
196
+ (1, 2),
197
+ (1, 3),
198
+ (3, 1),
199
+ (3, 2),
200
+ (3, 5),
201
+ (4, 5),
202
+ (4, 6),
203
+ (5, 4),
204
+ (5, 6),
205
+ (6, 4),
206
+ )
207
+ )
208
+ # fmt: off
209
+ GL = np.array([[ 0.9833, -0.2941, -0.3882, -0.0291, -0.0231, -0.0261],
210
+ [-0.2941, 0.8333, -0.2339, -0.0536, -0.0589, -0.0554],
211
+ [-0.3882, -0.2339, 0.9833, -0.0278, -0.0896, -0.0251],
212
+ [-0.0291, -0.0536, -0.0278, 0.9833, -0.4878, -0.6675],
213
+ [-0.0231, -0.0589, -0.0896, -0.4878, 0.9833, -0.2078],
214
+ [-0.0261, -0.0554, -0.0251, -0.6675, -0.2078, 0.9833]])
215
+ # fmt: on
216
+ L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G))
217
+ np.testing.assert_almost_equal(L, GL, decimal=3)
218
+
219
+ # Make the graph strongly connected, so we can use a random and lazy walk
220
+ G.add_edges_from(((2, 5), (6, 1)))
221
+ # fmt: off
222
+ GL = np.array([[ 1. , -0.3062, -0.4714, 0. , 0. , -0.3227],
223
+ [-0.3062, 1. , -0.1443, 0. , -0.3162, 0. ],
224
+ [-0.4714, -0.1443, 1. , 0. , -0.0913, 0. ],
225
+ [ 0. , 0. , 0. , 1. , -0.5 , -0.5 ],
226
+ [ 0. , -0.3162, -0.0913, -0.5 , 1. , -0.25 ],
227
+ [-0.3227, 0. , 0. , -0.5 , -0.25 , 1. ]])
228
+ # fmt: on
229
+ L = nx.directed_laplacian_matrix(
230
+ G, alpha=0.9, nodelist=sorted(G), walk_type="random"
231
+ )
232
+ np.testing.assert_almost_equal(L, GL, decimal=3)
233
+
234
+ # fmt: off
235
+ GL = np.array([[ 0.5 , -0.1531, -0.2357, 0. , 0. , -0.1614],
236
+ [-0.1531, 0.5 , -0.0722, 0. , -0.1581, 0. ],
237
+ [-0.2357, -0.0722, 0.5 , 0. , -0.0456, 0. ],
238
+ [ 0. , 0. , 0. , 0.5 , -0.25 , -0.25 ],
239
+ [ 0. , -0.1581, -0.0456, -0.25 , 0.5 , -0.125 ],
240
+ [-0.1614, 0. , 0. , -0.25 , -0.125 , 0.5 ]])
241
+ # fmt: on
242
+ L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G), walk_type="lazy")
243
+ np.testing.assert_almost_equal(L, GL, decimal=3)
244
+
245
+ # Make a strongly connected periodic graph
246
+ G = nx.DiGraph()
247
+ G.add_edges_from(((1, 2), (2, 4), (4, 1), (1, 3), (3, 4)))
248
+ # fmt: off
249
+ GL = np.array([[ 0.5 , -0.176, -0.176, -0.25 ],
250
+ [-0.176, 0.5 , 0. , -0.176],
251
+ [-0.176, 0. , 0.5 , -0.176],
252
+ [-0.25 , -0.176, -0.176, 0.5 ]])
253
+ # fmt: on
254
+ L = nx.directed_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G))
255
+ np.testing.assert_almost_equal(L, GL, decimal=3)
256
+
257
+
258
+ def test_directed_combinatorial_laplacian():
259
+ "Directed combinatorial Laplacian"
260
+ # Graph used as an example in Sec. 4.1 of Langville and Meyer,
261
+ # "Google's PageRank and Beyond". The graph contains dangling nodes, so
262
+ # the pagerank random walk is selected by directed_laplacian
263
+ G = nx.DiGraph()
264
+ G.add_edges_from(
265
+ (
266
+ (1, 2),
267
+ (1, 3),
268
+ (3, 1),
269
+ (3, 2),
270
+ (3, 5),
271
+ (4, 5),
272
+ (4, 6),
273
+ (5, 4),
274
+ (5, 6),
275
+ (6, 4),
276
+ )
277
+ )
278
+ # fmt: off
279
+ GL = np.array([[ 0.0366, -0.0132, -0.0153, -0.0034, -0.0020, -0.0027],
280
+ [-0.0132, 0.0450, -0.0111, -0.0076, -0.0062, -0.0069],
281
+ [-0.0153, -0.0111, 0.0408, -0.0035, -0.0083, -0.0027],
282
+ [-0.0034, -0.0076, -0.0035, 0.3688, -0.1356, -0.2187],
283
+ [-0.0020, -0.0062, -0.0083, -0.1356, 0.2026, -0.0505],
284
+ [-0.0027, -0.0069, -0.0027, -0.2187, -0.0505, 0.2815]])
285
+ # fmt: on
286
+
287
+ L = nx.directed_combinatorial_laplacian_matrix(G, alpha=0.9, nodelist=sorted(G))
288
+ np.testing.assert_almost_equal(L, GL, decimal=3)
289
+
290
+ # Make the graph strongly connected, so we can use a random and lazy walk
291
+ G.add_edges_from(((2, 5), (6, 1)))
292
+
293
+ # fmt: off
294
+ GL = np.array([[ 0.1395, -0.0349, -0.0465, 0. , 0. , -0.0581],
295
+ [-0.0349, 0.093 , -0.0116, 0. , -0.0465, 0. ],
296
+ [-0.0465, -0.0116, 0.0698, 0. , -0.0116, 0. ],
297
+ [ 0. , 0. , 0. , 0.2326, -0.1163, -0.1163],
298
+ [ 0. , -0.0465, -0.0116, -0.1163, 0.2326, -0.0581],
299
+ [-0.0581, 0. , 0. , -0.1163, -0.0581, 0.2326]])
300
+ # fmt: on
301
+
302
+ L = nx.directed_combinatorial_laplacian_matrix(
303
+ G, alpha=0.9, nodelist=sorted(G), walk_type="random"
304
+ )
305
+ np.testing.assert_almost_equal(L, GL, decimal=3)
306
+
307
+ # fmt: off
308
+ GL = np.array([[ 0.0698, -0.0174, -0.0233, 0. , 0. , -0.0291],
309
+ [-0.0174, 0.0465, -0.0058, 0. , -0.0233, 0. ],
310
+ [-0.0233, -0.0058, 0.0349, 0. , -0.0058, 0. ],
311
+ [ 0. , 0. , 0. , 0.1163, -0.0581, -0.0581],
312
+ [ 0. , -0.0233, -0.0058, -0.0581, 0.1163, -0.0291],
313
+ [-0.0291, 0. , 0. , -0.0581, -0.0291, 0.1163]])
314
+ # fmt: on
315
+
316
+ L = nx.directed_combinatorial_laplacian_matrix(
317
+ G, alpha=0.9, nodelist=sorted(G), walk_type="lazy"
318
+ )
319
+ np.testing.assert_almost_equal(L, GL, decimal=3)
320
+
321
+ E = nx.DiGraph(margulis_gabber_galil_graph(2))
322
+ L = nx.directed_combinatorial_laplacian_matrix(E)
323
+ # fmt: off
324
+ expected = np.array(
325
+ [[ 0.16666667, -0.08333333, -0.08333333, 0. ],
326
+ [-0.08333333, 0.16666667, 0. , -0.08333333],
327
+ [-0.08333333, 0. , 0.16666667, -0.08333333],
328
+ [ 0. , -0.08333333, -0.08333333, 0.16666667]]
329
+ )
330
+ # fmt: on
331
+ np.testing.assert_almost_equal(L, expected, decimal=6)
332
+
333
+ with pytest.raises(nx.NetworkXError):
334
+ nx.directed_combinatorial_laplacian_matrix(G, walk_type="pagerank", alpha=100)
335
+ with pytest.raises(nx.NetworkXError):
336
+ nx.directed_combinatorial_laplacian_matrix(G, walk_type="silly")
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_modularity.py ADDED
@@ -0,0 +1,87 @@
1
+ import pytest
2
+
3
+ np = pytest.importorskip("numpy")
4
+ pytest.importorskip("scipy")
5
+
6
+ import networkx as nx
7
+ from networkx.generators.degree_seq import havel_hakimi_graph
8
+
9
+
10
+ class TestModularity:
11
+ @classmethod
12
+ def setup_class(cls):
13
+ deg = [3, 2, 2, 1, 0]
14
+ cls.G = havel_hakimi_graph(deg)
15
+ # Graph used as an example in Sec. 4.1 of Langville and Meyer,
16
+ # "Google's PageRank and Beyond". (Used for test_directed_laplacian)
17
+ cls.DG = nx.DiGraph()
18
+ cls.DG.add_edges_from(
19
+ (
20
+ (1, 2),
21
+ (1, 3),
22
+ (3, 1),
23
+ (3, 2),
24
+ (3, 5),
25
+ (4, 5),
26
+ (4, 6),
27
+ (5, 4),
28
+ (5, 6),
29
+ (6, 4),
30
+ )
31
+ )
32
+
33
+ def test_modularity(self):
34
+ "Modularity matrix"
35
+ # fmt: off
36
+ B = np.array([[-1.125, 0.25, 0.25, 0.625, 0.],
37
+ [0.25, -0.5, 0.5, -0.25, 0.],
38
+ [0.25, 0.5, -0.5, -0.25, 0.],
39
+ [0.625, -0.25, -0.25, -0.125, 0.],
40
+ [0., 0., 0., 0., 0.]])
41
+ # fmt: on
42
+
43
+ permutation = [4, 0, 1, 2, 3]
44
+ np.testing.assert_equal(nx.modularity_matrix(self.G), B)
45
+ np.testing.assert_equal(
46
+ nx.modularity_matrix(self.G, nodelist=permutation),
47
+ B[np.ix_(permutation, permutation)],
48
+ )
49
+
50
+ def test_modularity_weight(self):
51
+ "Modularity matrix with weights"
52
+ # fmt: off
53
+ B = np.array([[-1.125, 0.25, 0.25, 0.625, 0.],
54
+ [0.25, -0.5, 0.5, -0.25, 0.],
55
+ [0.25, 0.5, -0.5, -0.25, 0.],
56
+ [0.625, -0.25, -0.25, -0.125, 0.],
57
+ [0., 0., 0., 0., 0.]])
58
+ # fmt: on
59
+
60
+ G_weighted = self.G.copy()
61
+ for n1, n2 in G_weighted.edges():
62
+ G_weighted.edges[n1, n2]["weight"] = 0.5
63
+ # The following test would fail in networkx 1.1
64
+ np.testing.assert_equal(nx.modularity_matrix(G_weighted), B)
65
+ # The following test that the modularity matrix get rescaled accordingly
66
+ np.testing.assert_equal(
67
+ nx.modularity_matrix(G_weighted, weight="weight"), 0.5 * B
68
+ )
69
+
70
+ def test_directed_modularity(self):
71
+ "Directed Modularity matrix"
72
+ # fmt: off
73
+ B = np.array([[-0.2, 0.6, 0.8, -0.4, -0.4, -0.4],
74
+ [0., 0., 0., 0., 0., 0.],
75
+ [0.7, 0.4, -0.3, -0.6, 0.4, -0.6],
76
+ [-0.2, -0.4, -0.2, -0.4, 0.6, 0.6],
77
+ [-0.2, -0.4, -0.2, 0.6, -0.4, 0.6],
78
+ [-0.1, -0.2, -0.1, 0.8, -0.2, -0.2]])
79
+ # fmt: on
80
+ node_permutation = [5, 1, 2, 3, 4, 6]
81
+ idx_permutation = [4, 0, 1, 2, 3, 5]
82
+ mm = nx.directed_modularity_matrix(self.DG, nodelist=sorted(self.DG))
83
+ np.testing.assert_equal(mm, B)
84
+ np.testing.assert_equal(
85
+ nx.directed_modularity_matrix(self.DG, nodelist=node_permutation),
86
+ B[np.ix_(idx_permutation, idx_permutation)],
87
+ )
pythonProject/.venv/Lib/site-packages/networkx/linalg/tests/test_spectrum.py ADDED
@@ -0,0 +1,71 @@
1
+ import pytest
2
+
3
+ np = pytest.importorskip("numpy")
4
+ pytest.importorskip("scipy")
5
+
6
+ import networkx as nx
7
+ from networkx.generators.degree_seq import havel_hakimi_graph
8
+
9
+
10
+ class TestSpectrum:
11
+ @classmethod
12
+ def setup_class(cls):
13
+ deg = [3, 2, 2, 1, 0]
14
+ cls.G = havel_hakimi_graph(deg)
15
+ cls.P = nx.path_graph(3)
16
+ cls.WG = nx.Graph(
17
+ (u, v, {"weight": 0.5, "other": 0.3}) for (u, v) in cls.G.edges()
18
+ )
19
+ cls.WG.add_node(4)
20
+ cls.DG = nx.DiGraph()
21
+ nx.add_path(cls.DG, [0, 1, 2])
22
+
23
+ def test_laplacian_spectrum(self):
24
+ "Laplacian eigenvalues"
25
+ evals = np.array([0, 0, 1, 3, 4])
26
+ e = sorted(nx.laplacian_spectrum(self.G))
27
+ np.testing.assert_almost_equal(e, evals)
28
+ e = sorted(nx.laplacian_spectrum(self.WG, weight=None))
29
+ np.testing.assert_almost_equal(e, evals)
30
+ e = sorted(nx.laplacian_spectrum(self.WG))
31
+ np.testing.assert_almost_equal(e, 0.5 * evals)
32
+ e = sorted(nx.laplacian_spectrum(self.WG, weight="other"))
33
+ np.testing.assert_almost_equal(e, 0.3 * evals)
34
+
35
+ def test_normalized_laplacian_spectrum(self):
36
+ "Normalized Laplacian eigenvalues"
37
+ evals = np.array([0, 0, 0.7712864461218, 1.5, 1.7287135538781])
38
+ e = sorted(nx.normalized_laplacian_spectrum(self.G))
39
+ np.testing.assert_almost_equal(e, evals)
40
+ e = sorted(nx.normalized_laplacian_spectrum(self.WG, weight=None))
41
+ np.testing.assert_almost_equal(e, evals)
42
+ e = sorted(nx.normalized_laplacian_spectrum(self.WG))
43
+ np.testing.assert_almost_equal(e, evals)
44
+ e = sorted(nx.normalized_laplacian_spectrum(self.WG, weight="other"))
45
+ np.testing.assert_almost_equal(e, evals)
46
+
47
+ def test_adjacency_spectrum(self):
48
+ "Adjacency eigenvalues"
49
+ evals = np.array([-np.sqrt(2), 0, np.sqrt(2)])
50
+ e = sorted(nx.adjacency_spectrum(self.P))
51
+ np.testing.assert_almost_equal(e, evals)
52
+
53
+ def test_modularity_spectrum(self):
54
+ "Modularity eigenvalues"
55
+ evals = np.array([-1.5, 0.0, 0.0])
56
+ e = sorted(nx.modularity_spectrum(self.P))
57
+ np.testing.assert_almost_equal(e, evals)
58
+ # Directed modularity eigenvalues
59
+ evals = np.array([-0.5, 0.0, 0.0])
60
+ e = sorted(nx.modularity_spectrum(self.DG))
61
+ np.testing.assert_almost_equal(e, evals)
62
+
63
+ def test_bethe_hessian_spectrum(self):
64
+ "Bethe Hessian eigenvalues"
65
+ evals = np.array([0.5 * (9 - np.sqrt(33)), 4, 0.5 * (9 + np.sqrt(33))])
66
+ e = sorted(nx.bethe_hessian_spectrum(self.P, r=2))
67
+ np.testing.assert_almost_equal(e, evals)
68
+ # Collapses back to Laplacian:
69
+ e1 = sorted(nx.bethe_hessian_spectrum(self.P, r=1))
70
+ e2 = sorted(nx.laplacian_spectrum(self.P))
71
+ np.testing.assert_almost_equal(e1, e2)
pythonProject/.venv/Lib/site-packages/numpy/ctypeslib.py ADDED
@@ -0,0 +1,602 @@
1
+ """
2
+ ============================
3
+ ``ctypes`` Utility Functions
4
+ ============================
5
+
6
+ See Also
7
+ --------
8
+ load_library : Load a C library.
9
+ ndpointer : Array restype/argtype with verification.
10
+ as_ctypes : Create a ctypes array from an ndarray.
11
+ as_array : Create an ndarray from a ctypes array.
12
+
13
+ References
14
+ ----------
15
+ .. [1] "SciPy Cookbook: ctypes", https://scipy-cookbook.readthedocs.io/items/Ctypes.html
16
+
17
+ Examples
18
+ --------
19
+ Load the C library:
20
+
21
+ >>> _lib = np.ctypeslib.load_library('libmystuff', '.') #doctest: +SKIP
22
+
23
+ Our result type, an ndarray that must be of type double, be 1-dimensional
24
+ and is C-contiguous in memory:
25
+
26
+ >>> array_1d_double = np.ctypeslib.ndpointer(
27
+ ... dtype=np.double,
28
+ ... ndim=1, flags='CONTIGUOUS') #doctest: +SKIP
29
+
30
+ Our C-function typically takes an array and updates its values
31
+ in-place. For example::
32
+
33
+ void foo_func(double* x, int length)
34
+ {
35
+ int i;
36
+ for (i = 0; i < length; i++) {
37
+ x[i] = i*i;
38
+ }
39
+ }
40
+
41
+ We wrap it using:
42
+
43
+ >>> _lib.foo_func.restype = None #doctest: +SKIP
44
+ >>> _lib.foo_func.argtypes = [array_1d_double, c_int] #doctest: +SKIP
45
+
46
+ Then, we're ready to call ``foo_func``:
47
+
48
+ >>> out = np.empty(15, dtype=np.double)
49
+ >>> _lib.foo_func(out, len(out)) #doctest: +SKIP
50
+
51
+ """
52
+ __all__ = ['load_library', 'ndpointer', 'c_intp', 'as_ctypes', 'as_array',
53
+ 'as_ctypes_type']
54
+
55
+ import os
56
+ import numpy as np
57
+ from numpy._core.multiarray import _flagdict, flagsobj
58
+
59
+ try:
60
+ import ctypes
61
+ except ImportError:
62
+ ctypes = None
63
+
64
+ if ctypes is None:
65
+ def _dummy(*args, **kwds):
66
+ """
67
+ Dummy object that raises an ImportError if ctypes is not available.
68
+
69
+ Raises
70
+ ------
71
+ ImportError
72
+ If ctypes is not available.
73
+
74
+ """
75
+ raise ImportError("ctypes is not available.")
76
+ load_library = _dummy
77
+ as_ctypes = _dummy
78
+ as_array = _dummy
79
+ from numpy import intp as c_intp
80
+ _ndptr_base = object
81
+ else:
82
+ import numpy._core._internal as nic
83
+ c_intp = nic._getintp_ctype()
84
+ del nic
85
+ _ndptr_base = ctypes.c_void_p
86
+
87
+ # Adapted from Albert Strasheim
88
+ def load_library(libname, loader_path):
89
+ """
90
+ It is possible to load a library using
91
+
92
+ >>> lib = ctypes.cdll[<full_path_name>] # doctest: +SKIP
93
+
94
+ But there are cross-platform considerations, such as library file extensions,
95
+ plus the fact Windows will just load the first library it finds with that name.
96
+ NumPy supplies the load_library function as a convenience.
97
+
98
+ .. versionchanged:: 1.20.0
99
+ Allow libname and loader_path to take any
100
+ :term:`python:path-like object`.
101
+
102
+ Parameters
103
+ ----------
104
+ libname : path-like
105
+ Name of the library, which can have 'lib' as a prefix,
106
+ but without an extension.
107
+ loader_path : path-like
108
+ Where the library can be found.
109
+
110
+ Returns
111
+ -------
112
+ ctypes.cdll[libpath] : library object
113
+ A ctypes library object
114
+
115
+ Raises
116
+ ------
117
+ OSError
118
+ If there is no library with the expected extension, or the
119
+ library is defective and cannot be loaded.
120
+ """
121
+ # Convert path-like objects into strings
122
+ libname = os.fsdecode(libname)
123
+ loader_path = os.fsdecode(loader_path)
124
+
125
+ ext = os.path.splitext(libname)[1]
126
+ if not ext:
127
+ import sys
128
+ import sysconfig
129
+ # Try to load library with platform-specific name, otherwise
130
+ # default to libname.[so|dll|dylib]. Sometimes, these files are
131
+ # built erroneously on non-linux platforms.
132
+ base_ext = ".so"
133
+ if sys.platform.startswith("darwin"):
134
+ base_ext = ".dylib"
135
+ elif sys.platform.startswith("win"):
136
+ base_ext = ".dll"
137
+ libname_ext = [libname + base_ext]
138
+ so_ext = sysconfig.get_config_var("EXT_SUFFIX")
139
+ if not so_ext == base_ext:
140
+ libname_ext.insert(0, libname + so_ext)
141
+ else:
142
+ libname_ext = [libname]
143
+
144
+ loader_path = os.path.abspath(loader_path)
145
+ if not os.path.isdir(loader_path):
146
+ libdir = os.path.dirname(loader_path)
147
+ else:
148
+ libdir = loader_path
149
+
150
+ for ln in libname_ext:
151
+ libpath = os.path.join(libdir, ln)
152
+ if os.path.exists(libpath):
153
+ try:
154
+ return ctypes.cdll[libpath]
155
+ except OSError:
156
+ ## defective lib file
157
+ raise
158
+ ## if no successful return in the libname_ext loop:
159
+ raise OSError("no file with expected extension")
160
+
161
+
162
+ def _num_fromflags(flaglist):
163
+ num = 0
164
+ for val in flaglist:
165
+ num += _flagdict[val]
166
+ return num
167
+
168
+ _flagnames = ['C_CONTIGUOUS', 'F_CONTIGUOUS', 'ALIGNED', 'WRITEABLE',
169
+ 'OWNDATA', 'WRITEBACKIFCOPY']
170
+ def _flags_fromnum(num):
171
+ res = []
172
+ for key in _flagnames:
173
+ value = _flagdict[key]
174
+ if (num & value):
175
+ res.append(key)
176
+ return res
177
+
178
+
179
+ class _ndptr(_ndptr_base):
180
+ @classmethod
181
+ def from_param(cls, obj):
182
+ if not isinstance(obj, np.ndarray):
183
+ raise TypeError("argument must be an ndarray")
184
+ if cls._dtype_ is not None \
185
+ and obj.dtype != cls._dtype_:
186
+ raise TypeError("array must have data type %s" % cls._dtype_)
187
+ if cls._ndim_ is not None \
188
+ and obj.ndim != cls._ndim_:
189
+ raise TypeError("array must have %d dimension(s)" % cls._ndim_)
190
+ if cls._shape_ is not None \
191
+ and obj.shape != cls._shape_:
192
+ raise TypeError("array must have shape %s" % str(cls._shape_))
193
+ if cls._flags_ is not None \
194
+ and ((obj.flags.num & cls._flags_) != cls._flags_):
195
+ raise TypeError("array must have flags %s" %
196
+ _flags_fromnum(cls._flags_))
197
+ return obj.ctypes
198
+
199
+
200
+ class _concrete_ndptr(_ndptr):
201
+ """
202
+ Like _ndptr, but with `_shape_` and `_dtype_` specified.
203
+
204
+ Notably, this means the pointer has enough information to reconstruct
205
+ the array, which is not generally true.
206
+ """
207
+ def _check_retval_(self):
208
+ """
209
+ This method is called when this class is used as the .restype
210
+ attribute for a shared-library function, to automatically wrap the
211
+ pointer into an array.
212
+ """
213
+ return self.contents
214
+
215
+ @property
216
+ def contents(self):
217
+ """
218
+ Get an ndarray viewing the data pointed to by this pointer.
219
+
220
+ This mirrors the `contents` attribute of a normal ctypes pointer
221
+ """
222
+ full_dtype = np.dtype((self._dtype_, self._shape_))
223
+ full_ctype = ctypes.c_char * full_dtype.itemsize
224
+ buffer = ctypes.cast(self, ctypes.POINTER(full_ctype)).contents
225
+ return np.frombuffer(buffer, dtype=full_dtype).squeeze(axis=0)
226
+
227
+
228
+ # Factory for an array-checking class with from_param defined for
229
+ # use with ctypes argtypes mechanism
230
+ _pointer_type_cache = {}
231
+ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
232
+ """
233
+ Array-checking restype/argtypes.
234
+
235
+ An ndpointer instance is used to describe an ndarray in restypes
236
+ and argtypes specifications. This approach is more flexible than
237
+ using, for example, ``POINTER(c_double)``, since several restrictions
238
+ can be specified, which are verified upon calling the ctypes function.
239
+ These include data type, number of dimensions, shape and flags. If a
240
+ given array does not satisfy the specified restrictions,
241
+ a ``TypeError`` is raised.
242
+
243
+ Parameters
244
+ ----------
245
+ dtype : data-type, optional
246
+ Array data-type.
247
+ ndim : int, optional
248
+ Number of array dimensions.
249
+ shape : tuple of ints, optional
250
+ Array shape.
251
+ flags : str or tuple of str
252
+ Array flags; may be one or more of:
253
+
254
+ - C_CONTIGUOUS / C / CONTIGUOUS
255
+ - F_CONTIGUOUS / F / FORTRAN
256
+ - OWNDATA / O
257
+ - WRITEABLE / W
258
+ - ALIGNED / A
259
+ - WRITEBACKIFCOPY / X
260
+
261
+ Returns
262
+ -------
263
+ klass : ndpointer type object
264
+ A type object, which is an ``_ndtpr`` instance containing
265
+ dtype, ndim, shape and flags information.
266
+
267
+ Raises
268
+ ------
269
+ TypeError
270
+ If a given array does not satisfy the specified restrictions.
271
+
272
+ Examples
273
+ --------
274
+ >>> clib.somefunc.argtypes = [np.ctypeslib.ndpointer(dtype=np.float64,
275
+ ... ndim=1,
276
+ ... flags='C_CONTIGUOUS')]
277
+ ... #doctest: +SKIP
278
+ >>> clib.somefunc(np.array([1, 2, 3], dtype=np.float64))
279
+ ... #doctest: +SKIP
280
+
281
+ """
282
+
283
+ # normalize dtype to dtype | None
284
+ if dtype is not None:
285
+ dtype = np.dtype(dtype)
286
+
287
+ # normalize flags to int | None
288
+ num = None
289
+ if flags is not None:
290
+ if isinstance(flags, str):
291
+ flags = flags.split(',')
292
+ elif isinstance(flags, (int, np.integer)):
293
+ num = flags
294
+ flags = _flags_fromnum(num)
295
+ elif isinstance(flags, flagsobj):
296
+ num = flags.num
297
+ flags = _flags_fromnum(num)
298
+ if num is None:
299
+ try:
300
+ flags = [x.strip().upper() for x in flags]
301
+ except Exception as e:
302
+ raise TypeError("invalid flags specification") from e
303
+ num = _num_fromflags(flags)
304
+
305
+ # normalize shape to tuple | None
306
+ if shape is not None:
307
+ try:
308
+ shape = tuple(shape)
309
+ except TypeError:
310
+ # single integer -> 1-tuple
311
+ shape = (shape,)
312
+
313
+ cache_key = (dtype, ndim, shape, num)
314
+
315
+ try:
316
+ return _pointer_type_cache[cache_key]
317
+ except KeyError:
318
+ pass
319
+
320
+ # produce a name for the new type
321
+ if dtype is None:
322
+ name = 'any'
323
+ elif dtype.names is not None:
324
+ name = str(id(dtype))
325
+ else:
326
+ name = dtype.str
327
+ if ndim is not None:
328
+ name += "_%dd" % ndim
329
+ if shape is not None:
330
+ name += "_"+"x".join(str(x) for x in shape)
331
+ if flags is not None:
332
+ name += "_"+"_".join(flags)
333
+
334
+ if dtype is not None and shape is not None:
335
+ base = _concrete_ndptr
336
+ else:
337
+ base = _ndptr
338
+
339
+ klass = type("ndpointer_%s"%name, (base,),
340
+ {"_dtype_": dtype,
341
+ "_shape_" : shape,
342
+ "_ndim_" : ndim,
343
+ "_flags_" : num})
344
+ _pointer_type_cache[cache_key] = klass
345
+ return klass
346
+
347
+
348
+ if ctypes is not None:
349
+ def _ctype_ndarray(element_type, shape):
350
+ """ Create an ndarray of the given element type and shape """
351
+ for dim in shape[::-1]:
352
+ element_type = dim * element_type
353
+ # prevent the type name include np.ctypeslib
354
+ element_type.__module__ = None
355
+ return element_type
356
+
357
+
358
+ def _get_scalar_type_map():
359
+ """
360
+ Return a dictionary mapping native endian scalar dtype to ctypes types
361
+ """
362
+ ct = ctypes
363
+ simple_types = [
364
+ ct.c_byte, ct.c_short, ct.c_int, ct.c_long, ct.c_longlong,
365
+ ct.c_ubyte, ct.c_ushort, ct.c_uint, ct.c_ulong, ct.c_ulonglong,
366
+ ct.c_float, ct.c_double,
367
+ ct.c_bool,
368
+ ]
369
+ return {np.dtype(ctype): ctype for ctype in simple_types}
370
+
371
+
372
+ _scalar_type_map = _get_scalar_type_map()
373
+
374
+
375
+ def _ctype_from_dtype_scalar(dtype):
376
+ # swapping twice ensure that `=` is promoted to <, >, or |
377
+ dtype_with_endian = dtype.newbyteorder('S').newbyteorder('S')
378
+ dtype_native = dtype.newbyteorder('=')
379
+ try:
380
+ ctype = _scalar_type_map[dtype_native]
381
+ except KeyError as e:
382
+ raise NotImplementedError(
383
+ "Converting {!r} to a ctypes type".format(dtype)
384
+ ) from None
385
+
386
+ if dtype_with_endian.byteorder == '>':
387
+ ctype = ctype.__ctype_be__
388
+ elif dtype_with_endian.byteorder == '<':
389
+ ctype = ctype.__ctype_le__
390
+
391
+ return ctype
392
+
393
+
394
+ def _ctype_from_dtype_subarray(dtype):
395
+ element_dtype, shape = dtype.subdtype
396
+ ctype = _ctype_from_dtype(element_dtype)
397
+ return _ctype_ndarray(ctype, shape)
398
+
399
+
400
+ def _ctype_from_dtype_structured(dtype):
401
+ # extract offsets of each field
402
+ field_data = []
403
+ for name in dtype.names:
404
+ field_dtype, offset = dtype.fields[name][:2]
405
+ field_data.append((offset, name, _ctype_from_dtype(field_dtype)))
406
+
407
+ # ctypes doesn't care about field order
408
+ field_data = sorted(field_data, key=lambda f: f[0])
409
+
410
+ if len(field_data) > 1 and all(offset == 0 for offset, name, ctype in field_data):
411
+ # union, if multiple fields all at address 0
412
+ size = 0
413
+ _fields_ = []
414
+ for offset, name, ctype in field_data:
415
+ _fields_.append((name, ctype))
416
+ size = max(size, ctypes.sizeof(ctype))
417
+
418
+ # pad to the right size
419
+ if dtype.itemsize != size:
420
+ _fields_.append(('', ctypes.c_char * dtype.itemsize))
421
+
422
+ # we inserted manual padding, so always `_pack_`
423
+ return type('union', (ctypes.Union,), dict(
424
+ _fields_=_fields_,
425
+ _pack_=1,
426
+ __module__=None,
427
+ ))
428
+ else:
429
+ last_offset = 0
430
+ _fields_ = []
431
+ for offset, name, ctype in field_data:
432
+ padding = offset - last_offset
433
+ if padding < 0:
434
+ raise NotImplementedError("Overlapping fields")
435
+ if padding > 0:
436
+ _fields_.append(('', ctypes.c_char * padding))
437
+
438
+ _fields_.append((name, ctype))
439
+ last_offset = offset + ctypes.sizeof(ctype)
440
+
441
+
442
+ padding = dtype.itemsize - last_offset
443
+ if padding > 0:
444
+ _fields_.append(('', ctypes.c_char * padding))
445
+
446
+ # we inserted manual padding, so always `_pack_`
447
+ return type('struct', (ctypes.Structure,), dict(
448
+ _fields_=_fields_,
449
+ _pack_=1,
450
+ __module__=None,
451
+ ))
452
+
453
+
454
+ def _ctype_from_dtype(dtype):
455
+ if dtype.fields is not None:
456
+ return _ctype_from_dtype_structured(dtype)
457
+ elif dtype.subdtype is not None:
458
+ return _ctype_from_dtype_subarray(dtype)
459
+ else:
460
+ return _ctype_from_dtype_scalar(dtype)
461
+
462
+
463
+ def as_ctypes_type(dtype):
464
+ r"""
465
+ Convert a dtype into a ctypes type.
466
+
467
+ Parameters
468
+ ----------
469
+ dtype : dtype
470
+ The dtype to convert
471
+
472
+ Returns
473
+ -------
474
+ ctype
475
+ A ctype scalar, union, array, or struct
476
+
477
+ Raises
478
+ ------
479
+ NotImplementedError
480
+ If the conversion is not possible
481
+
482
+ Notes
483
+ -----
484
+ This function does not losslessly round-trip in either direction.
485
+
486
+ ``np.dtype(as_ctypes_type(dt))`` will:
487
+
488
+ - insert padding fields
489
+ - reorder fields to be sorted by offset
490
+ - discard field titles
491
+
492
+ ``as_ctypes_type(np.dtype(ctype))`` will:
493
+
494
+ - discard the class names of `ctypes.Structure`\ s and
495
+ `ctypes.Union`\ s
496
+ - convert single-element `ctypes.Union`\ s into single-element
497
+ `ctypes.Structure`\ s
498
+ - insert padding fields
499
+
500
+ Examples
501
+ --------
502
+ Converting a simple dtype:
503
+
504
+ >>> dt = np.dtype('int8')
505
+ >>> ctype = np.ctypeslib.as_ctypes_type(dt)
506
+ >>> ctype
507
+ <class 'ctypes.c_byte'>
508
+
509
+ Converting a structured dtype:
510
+
511
+ >>> dt = np.dtype([('x', 'i4'), ('y', 'f4')])
512
+ >>> ctype = np.ctypeslib.as_ctypes_type(dt)
513
+ >>> ctype
514
+ <class 'struct'>
515
+
516
+ """
517
+ return _ctype_from_dtype(np.dtype(dtype))
518
+
519
+
520
+ def as_array(obj, shape=None):
521
+ """
522
+ Create a numpy array from a ctypes array or POINTER.
523
+
524
+ The numpy array shares the memory with the ctypes object.
525
+
526
+ The shape parameter must be given if converting from a ctypes POINTER.
527
+ The shape parameter is ignored if converting from a ctypes array
528
+
529
+ Examples
530
+ --------
531
+ Converting a ctypes integer array:
532
+
533
+ >>> import ctypes
534
+ >>> ctypes_array = (ctypes.c_int * 5)(0, 1, 2, 3, 4)
535
+ >>> np_array = np.ctypeslib.as_array(ctypes_array)
536
+ >>> np_array
537
+ array([0, 1, 2, 3, 4], dtype=int32)
538
+
539
+ Converting a ctypes POINTER:
540
+
541
+ >>> import ctypes
542
+ >>> buffer = (ctypes.c_int * 5)(0, 1, 2, 3, 4)
543
+ >>> pointer = ctypes.cast(buffer, ctypes.POINTER(ctypes.c_int))
544
+ >>> np_array = np.ctypeslib.as_array(pointer, (5,))
545
+ >>> np_array
546
+ array([0, 1, 2, 3, 4], dtype=int32)
547
+
548
+ """
549
+ if isinstance(obj, ctypes._Pointer):
550
+ # convert pointers to an array of the desired shape
551
+ if shape is None:
552
+ raise TypeError(
553
+ 'as_array() requires a shape argument when called on a '
554
+ 'pointer')
555
+ p_arr_type = ctypes.POINTER(_ctype_ndarray(obj._type_, shape))
556
+ obj = ctypes.cast(obj, p_arr_type).contents
557
+
558
+ return np.asarray(obj)
559
+
560
+
561
+ def as_ctypes(obj):
562
+ """
563
+ Create and return a ctypes object from a numpy array. Actually
564
+ anything that exposes the __array_interface__ is accepted.
565
+
566
+ Examples
567
+ --------
568
+ Create ctypes object from inferred int ``np.array``:
569
+
570
+ >>> inferred_int_array = np.array([1, 2, 3])
571
+ >>> c_int_array = np.ctypeslib.as_ctypes(inferred_int_array)
572
+ >>> type(c_int_array)
573
+ <class 'c_long_Array_3'>
574
+ >>> c_int_array[:]
575
+ [1, 2, 3]
576
+
577
+ Create ctypes object from explicit 8 bit unsigned int ``np.array`` :
578
+
579
+ >>> exp_int_array = np.array([1, 2, 3], dtype=np.uint8)
580
+ >>> c_int_array = np.ctypeslib.as_ctypes(exp_int_array)
581
+ >>> type(c_int_array)
582
+ <class 'c_ubyte_Array_3'>
583
+ >>> c_int_array[:]
584
+ [1, 2, 3]
585
+
586
+ """
587
+ ai = obj.__array_interface__
588
+ if ai["strides"]:
589
+ raise TypeError("strided arrays not supported")
590
+ if ai["version"] != 3:
591
+ raise TypeError("only __array_interface__ version 3 supported")
592
+ addr, readonly = ai["data"]
593
+ if readonly:
594
+ raise TypeError("readonly arrays unsupported")
595
+
596
+ # can't use `_dtype((ai["typestr"], ai["shape"]))` here, as it overflows
597
+ # dtype.itemsize (gh-14214)
598
+ ctype_scalar = as_ctypes_type(ai["typestr"])
599
+ result_type = _ctype_ndarray(ctype_scalar, ai["shape"])
600
+ result = result_type.from_address(addr)
601
+ result.__keep = obj
602
+ return result
pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/MTIAHooksInterface.h ADDED
@@ -0,0 +1,103 @@
1
+ #pragma once
2
+
3
+ #include <c10/core/Device.h>
4
+ #include <c10/util/Exception.h>
5
+
6
+ #include <c10/core/Stream.h>
7
+ #include <c10/util/Registry.h>
8
+
9
+ #include <ATen/detail/AcceleratorHooksInterface.h>
10
+
11
+ #include <string>
12
+
13
+ namespace at {
14
+ class Context;
15
+ }
16
+
17
+ namespace at {
18
+
19
+ constexpr const char* MTIA_HELP =
20
+ "The MTIA backend requires MTIA extension for PyTorch;"
21
+ "this error has occurred because you are trying "
22
+ "to use some MTIA's functionality without MTIA extension included.";
23
+
24
+ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
25
+ // this fails the implementation if MTIAHooks functions are called, but
26
+ // MTIA backend is not present.
27
+ #define FAIL_MTIAHOOKS_FUNC(func) \
28
+ TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
29
+
30
+ ~MTIAHooksInterface() override = default;
31
+
32
+ virtual void initMTIA() const {
33
+ // Avoid logging here, since MTIA needs init devices first then it will know
34
+ // how many devices are available. Make it as no-op if mtia extension is not
35
+ // dynamically loaded.
36
+ return;
37
+ }
38
+
39
+ virtual bool hasMTIA() const {
40
+ return false;
41
+ }
42
+
43
+ DeviceIndex deviceCount() const override {
44
+ return 0;
45
+ }
46
+
47
+ virtual void deviceSynchronize(c10::DeviceIndex device_index) const {
48
+ FAIL_MTIAHOOKS_FUNC(__func__);
49
+ }
50
+
51
+ virtual std::string showConfig() const {
52
+ FAIL_MTIAHOOKS_FUNC(__func__);
53
+ }
54
+
55
+ bool hasPrimaryContext(DeviceIndex device_index) const override {
56
+ return false;
57
+ }
58
+
59
+ void setCurrentDevice(DeviceIndex device) const override {
60
+ FAIL_MTIAHOOKS_FUNC(__func__);
61
+ }
62
+
63
+ DeviceIndex getCurrentDevice() const override {
64
+ FAIL_MTIAHOOKS_FUNC(__func__);
65
+ return -1;
66
+ }
67
+
68
+ DeviceIndex exchangeDevice(DeviceIndex device) const override {
69
+ FAIL_MTIAHOOKS_FUNC(__func__);
70
+ return -1;
71
+ }
72
+
73
+ DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
74
+ FAIL_MTIAHOOKS_FUNC(__func__);
75
+ return -1;
76
+ }
77
+
78
+ virtual c10::Stream getCurrentStream(DeviceIndex device) const {
79
+ FAIL_MTIAHOOKS_FUNC(__func__);
80
+ return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
81
+ }
82
+
83
+ virtual c10::Stream getDefaultStream(DeviceIndex device) const {
84
+ FAIL_MTIAHOOKS_FUNC(__func__);
85
+ return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
86
+ }
87
+
88
+ virtual void setCurrentStream(const c10::Stream& stream) const {
89
+ FAIL_MTIAHOOKS_FUNC(__func__);
90
+ }
91
+ };
92
+
93
+ struct TORCH_API MTIAHooksArgs {};
94
+
95
+ C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);
96
+ #define REGISTER_MTIA_HOOKS(clsname) \
97
+ C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname)
98
+
99
+ namespace detail {
100
+ TORCH_API const MTIAHooksInterface& getMTIAHooks();
101
+ TORCH_API bool isMTIAHooksBuilt();
102
+ } // namespace detail
103
+ } // namespace at
pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/PrivateUse1HooksInterface.h ADDED
@@ -0,0 +1,61 @@
1
+ #pragma once
2
+
3
+ #include <ATen/core/Generator.h>
4
+ #include <ATen/detail/AcceleratorHooksInterface.h>
5
+ #include <c10/core/Allocator.h>
6
+ #include <c10/core/Device.h>
7
+ #include <c10/core/Storage.h>
8
+ #include <c10/util/Exception.h>
9
+ namespace at {
10
+
11
+ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
12
+ ~PrivateUse1HooksInterface() override = default;
13
+ virtual const at::Generator& getDefaultGenerator(
14
+ c10::DeviceIndex device_index) {
15
+ TORCH_CHECK_NOT_IMPLEMENTED(
16
+ false,
17
+ "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`.");
18
+ }
19
+
20
+ virtual at::Device getDeviceFromPtr(void* data) const {
21
+ TORCH_CHECK_NOT_IMPLEMENTED(
22
+ false,
23
+ "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`.");
24
+ }
25
+
26
+ virtual Allocator* getPinnedMemoryAllocator() const {
27
+ TORCH_CHECK(
28
+ false,
29
+ "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getPinnedMemoryAllocator`.");
30
+ }
31
+
32
+ bool hasPrimaryContext(DeviceIndex device_index) const override {
33
+ TORCH_CHECK_NOT_IMPLEMENTED(
34
+ false,
35
+ "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`.");
36
+ }
37
+
38
+ virtual void initPrivateUse1() const {}
39
+ virtual void resizePrivateUse1Bytes(const c10::Storage &storage, size_t newsize) const {
40
+ TORCH_CHECK_NOT_IMPLEMENTED(
41
+ false,
42
+ "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `resizePrivateUse1Bytes`.");
43
+ }
44
+ };
45
+
46
+ struct TORCH_API PrivateUse1HooksArgs {};
47
+
48
+ TORCH_API void RegisterPrivateUse1HooksInterface(
49
+ at::PrivateUse1HooksInterface* hook_);
50
+
51
+ TORCH_API at::PrivateUse1HooksInterface* GetPrivateUse1HooksInterface();
52
+
53
+ TORCH_API bool isPrivateUse1HooksRegistered();
54
+
55
+ namespace detail {
56
+
57
+ TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
58
+
59
+ } // namespace detail
60
+
61
+ } // namespace at
pythonProject/.venv/Lib/site-packages/torch/include/ATen/detail/XPUHooksInterface.h ADDED
@@ -0,0 +1,84 @@
1
+ #pragma once
2
+
3
+ #include <c10/core/Device.h>
4
+ #include <c10/util/Exception.h>
5
+ #include <ATen/core/Generator.h>
6
+ #include <c10/util/Registry.h>
7
+
8
+ namespace at {
9
+
10
+ constexpr const char* XPU_HELP =
11
+ "The XPU backend requires Intel Extension for Pytorch;"
12
+ "this error has occurred because you are trying "
13
+ "to use some XPU's functionality, but the Intel Extension for Pytorch has not been "
14
+ "loaded for some reason. The Intel Extension for Pytorch MUST "
15
+ "be loaded, EVEN IF you don't directly use any symbols from that!";
16
+
17
+ struct TORCH_API XPUHooksInterface {
18
+ virtual ~XPUHooksInterface() = default;
19
+
20
+ virtual void initXPU() const {
21
+ TORCH_CHECK(
22
+ false,
23
+ "Cannot initialize XPU without Intel Extension for Pytorch.",
24
+ XPU_HELP);
25
+ }
26
+
27
+ virtual bool hasXPU() const {
28
+ return false;
29
+ }
30
+
31
+ virtual std::string showConfig() const {
32
+ TORCH_CHECK(
33
+ false,
34
+ "Cannot query detailed XPU version without Intel Extension for Pytorch. ",
35
+ XPU_HELP);
36
+ }
37
+
38
+ virtual int32_t getGlobalIdxFromDevice(const Device& device) const {
39
+ TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library.");
40
+ }
41
+
42
+ virtual Generator getXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
43
+ TORCH_CHECK(false, "Cannot get XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
44
+ }
45
+
46
+ virtual const Generator& getDefaultXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const {
47
+ TORCH_CHECK(false, "Cannot get default XPU generator without Intel Extension for Pytorch. ", XPU_HELP);
48
+ }
49
+
50
+ virtual DeviceIndex getNumGPUs() const {
51
+ return 0;
52
+ }
53
+
54
+ virtual DeviceIndex current_device() const {
55
+ TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library.");
56
+ }
57
+
58
+ virtual Device getDeviceFromPtr(void* /*data*/) const {
59
+ TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library.");
60
+ }
61
+
62
+ virtual void deviceSynchronize(DeviceIndex /*device_index*/) const {
63
+ TORCH_CHECK(false, "Cannot synchronize XPU device without ATen_xpu library.");
64
+ }
65
+
66
+ virtual Allocator* getPinnedMemoryAllocator() const {
67
+ TORCH_CHECK(false, "Cannot get XPU pinned memory allocator without ATen_xpu library.");
68
+ }
69
+
70
+ virtual bool isPinnedPtr(const void* /*data*/) const {
71
+ return false;
72
+ }
73
+ };
74
+
75
+ struct TORCH_API XPUHooksArgs {};
76
+
77
+ C10_DECLARE_REGISTRY(XPUHooksRegistry, XPUHooksInterface, XPUHooksArgs);
78
+ #define REGISTER_XPU_HOOKS(clsname) \
79
+ C10_REGISTER_CLASS(XPUHooksRegistry, clsname, clsname)
80
+
81
+ namespace detail {
82
+ TORCH_API const XPUHooksInterface& getXPUHooks();
83
+ } // namespace detail
84
+ } // namespace at
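The XPU hooks are wired up through the registry declared above rather than an explicit register call. A sketch of a registration, assuming a hypothetical `MyXPUHooks` class whose constructor accepts the registry's `XPUHooksArgs`:

    #include <ATen/detail/XPUHooksInterface.h>

    namespace at {

    // Hypothetical implementation; methods that are not overridden keep the
    // "Intel Extension for PyTorch is missing" errors from the base class.
    struct MyXPUHooks : public XPUHooksInterface {
      explicit MyXPUHooks(XPUHooksArgs /*args*/) {}
      bool hasXPU() const override { return true; }
      DeviceIndex getNumGPUs() const override { return 1; }
    };

    // Makes the class discoverable via at::detail::getXPUHooks().
    REGISTER_XPU_HOOKS(MyXPUHooks);

    } // namespace at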
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/ADInterpreters.h ADDED
@@ -0,0 +1,38 @@
1
+ #pragma once
2
+ #include <ATen/functorch/Interpreter.h>
3
+
4
+ namespace at::functorch {
5
+
6
+ // These are the interpreters for our AD transforms
7
+ // (grad, vjp and jvp).
8
+ // See NOTE: [functorch interpreter stack] for more details.
9
+
10
+ struct TORCH_API GradInterpreterPtr {
11
+ explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
12
+ TransformType key() const { return base_->key(); }
13
+ int64_t level() const { return base_->level(); }
14
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
15
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
16
+ bool prevGradMode() const {
17
+ return std::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
18
+ }
19
+ Tensor lift(const Tensor& tensor) const;
20
+ private:
21
+ const Interpreter* base_;
22
+ };
23
+
24
+ struct TORCH_API JvpInterpreterPtr {
25
+ explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
26
+ TransformType key() const { return base_->key(); }
27
+ int64_t level() const { return base_->level(); }
28
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
29
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
30
+ bool prevFwdGradMode() const {
31
+ return std::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
32
+ }
33
+ Tensor lift(const Tensor& tensor) const;
34
+ private:
35
+ const Interpreter* base_;
36
+ };
37
+
38
+ } // namespace at::functorch
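The `*InterpreterPtr` types above are thin, non-owning views over a generic `Interpreter`. A sketch of how a caller might use one (the free function is illustrative):

    #include <ATen/functorch/ADInterpreters.h>

    bool layer_had_grad_enabled(const at::functorch::Interpreter& interp) {
      using namespace at::functorch;
      if (interp.key() != TransformType::Grad) {
        return false;
      }
      // The wrapper asserts the key matches, then exposes Grad-specific metadata.
      return GradInterpreterPtr(&interp).prevGradMode();
    }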
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchRulesHelper.h ADDED
@@ -0,0 +1,475 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+ #pragma once
7
+
8
+ #include <c10/util/TypeList.h>
9
+
10
+ #include <ATen/ATen.h>
11
+ #include <ATen/Operators.h>
12
+
13
+ #include <ATen/functorch/DynamicLayer.h>
14
+ #include <ATen/functorch/TensorWrapper.h>
15
+ #include <ATen/functorch/BatchingMetaprogramming.h>
16
+ #include <ATen/functorch/LegacyVmapTransforms.h>
17
+ #include <ATen/functorch/BatchedFallback.h>
18
+ #include <ATen/functorch/PlumbingHelper.h>
19
+ #include <ATen/core/dispatch/Dispatcher.h>
20
+ #include <ATen/VmapGeneratedPlumbing.h>
21
+
22
+ #include <utility>
23
+
24
+ // This file contains helper functions for batching rules.
25
+
26
+ namespace at::functorch {
27
+
28
+ TORCH_API Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x);
29
+ TORCH_API Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x);
30
+
31
+ TORCH_API Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x);
32
+
33
+ Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
34
+ int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
35
+ int64_t numelWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
36
+ optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
37
+ int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
38
+ VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims);
39
+
40
+ void vmapIncompatibleInplaceError(const char* schema_name);
41
+
42
+ Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank);
43
+
44
+ void check_randomness(RandomnessType randomness);
45
+ void check_randomness(RandomnessType randomness, bool any_tensor_bdim);
46
+
47
+ inline Tensor ensure_has_bdim(const Tensor& tensor, bool has_bdim, c10::SymInt batch_size) {
48
+ if (has_bdim) {
49
+ return tensor;
50
+ }
51
+ const auto sizes = tensor.sym_sizes();
52
+ SymDimVector expanded_shape;
53
+ expanded_shape.reserve(sizes.size());
54
+ expanded_shape.emplace_back(std::move(batch_size));
55
+ expanded_shape.insert(expanded_shape.end(), sizes.begin(), sizes.end());
56
+ return tensor.expand_symint(expanded_shape);
57
+ }
58
+
59
+ #define VMAP_SUPPORT(op, batch_rule) \
60
+ m.impl(#op, op ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);
61
+
62
+ #define VMAP_SUPPORT2(op, overload, batch_rule) \
63
+ m.impl(#op "." #overload, op ## _ ## overload ## _generated_plumbing<decltype(&batch_rule), &batch_rule>);
64
+
65
+ #define OP_DECOMPOSE(op) m.impl(#op, static_cast<decltype(&ATEN_FN(op))>(native::op));
66
+ #define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast<decltype(&ATEN_FN2(op, overload))>(native::op));
67
+
68
+ // DO NOT USE ME DIRECTLY! Use BASIC_UNARY_BATCH_RULE to save yourself some pain
69
+ template <typename A, A a, typename C>
70
+ struct BasicUnaryBatchRuleHelper;
71
+
72
+ template <typename F, F Func, typename A, typename... T>
73
+ struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
74
+ static std::tuple<Tensor,optional<int64_t>> apply(
75
+ const Tensor& tensor,
76
+ optional<int64_t> batch_dim,
77
+ T... extra_args) {
78
+ return std::make_tuple(Func(tensor, std::forward<T>(extra_args)...), batch_dim);
79
+ }
80
+ };
81
+
82
+ // USAGE: BASIC_UNARY_BATCH_RULE(at::sin)
83
+ // INCORRECT USAGE: BASIC_UNARY_BATCH_RULE(&at::sin)
84
+ // It is important that this macro is not passed a function pointer!!
85
+ #define BASIC_UNARY_BATCH_RULE(fn) SINGLE_ARG(\
86
+ BasicUnaryBatchRuleHelper<\
87
+ decltype(&fn),\
88
+ &fn,\
89
+ c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
90
+
91
+ #define UNARY_POINTWISE(op) \
92
+ VMAP_SUPPORT(op, BASIC_UNARY_BATCH_RULE(ATEN_FN(op)));
93
+
94
+ template <typename A, A a, typename C>
95
+ struct VariadicBdimsBatchRuleHelper;
96
+
97
+ template <typename F, F Func, typename A, typename... T>
98
+ struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
99
+ static std::tuple<Tensor,optional<int64_t>> apply(
100
+ const Tensor& tensor,
101
+ optional<int64_t> batch_dim,
102
+ T... extra_args) {
103
+ auto tensor_ = moveBatchDimToFront(tensor, batch_dim);
104
+ return std::make_tuple(Func(tensor_, std::forward<T>(extra_args)...), 0);
105
+ }
106
+ };
107
+
108
+ // USAGE: VARIADIC_BDIMS_BATCH_RULE(at::cholesky_inverse)
109
+ // INCORRECT USAGE: VARIADIC_BDIMS_BATCH_RULE(&at::cholesky_inverse)
110
+ // It is important that this macro is not passed a function pointer!!
111
+ #define VARIADIC_BDIMS_BATCH_RULE(fn) SINGLE_ARG(\
112
+ VariadicBdimsBatchRuleHelper<\
113
+ decltype(&fn),\
114
+ &fn,\
115
+ c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
116
+
117
+ #define VARIADIC_BDIMS(op) \
118
+ VMAP_SUPPORT(op, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(op)));
119
+
120
+ #define VARIADIC_BDIMS2(op, overload) \
121
+ VMAP_SUPPORT2(op, overload, VARIADIC_BDIMS_BATCH_RULE(ATEN_FN2(op, overload)));
122
+
123
+ template<class F, F Func>
124
+ void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
125
+ const auto& schema = op.schema();
126
+ const auto num_returns = schema.returns().size();
127
+ const auto num_arguments = schema.arguments().size();
128
+
129
+ c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
130
+ auto maybe_layer = maybeCurrentDynamicLayer();
131
+ vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");
132
+
133
+ int64_t cur_level = maybe_layer->layerId();
134
+
135
+ auto orig_arguments = torch::jit::last(*stack, num_arguments);
136
+ if (std::none_of(orig_arguments.begin(), orig_arguments.end(), ivalueParticipatesInCurrentLevel)) {
137
+ op.callBoxed(stack);
138
+ return;
139
+ }
140
+
141
+ auto arguments = torch::jit::pop(*stack, num_arguments);
142
+ std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
143
+ std::vector<int64_t> tensor_pos;
144
+ for (const auto idx : c10::irange(0, num_arguments)) {
145
+ const auto& ivalue = arguments[idx];
146
+ if (ivalue.isTensor()) {
147
+ auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
148
+ tensor_inputs.emplace_back(tensor_value, tensor_bdim);
149
+ tensor_pos.push_back(static_cast<int64_t>(idx));
150
+ }
151
+ }
152
+ Func(tensor_inputs);
153
+
154
+ size_t tensor_idx = 0;
155
+ TORCH_INTERNAL_ASSERT(!tensor_pos.empty());
156
+ for (const auto arg_idx : c10::irange(0, num_arguments)) {
157
+ if (tensor_idx >= tensor_pos.size() || (int64_t)arg_idx != tensor_pos[tensor_idx]) {
158
+ torch::jit::push(stack, arguments[arg_idx]);
159
+ } else {
160
+ TORCH_INTERNAL_ASSERT(tensor_idx < tensor_inputs.size());
161
+ torch::jit::push(stack, tensor_inputs[tensor_idx].first);
162
+ tensor_idx++;
163
+ }
164
+ }
165
+
166
+ op.callBoxed(stack);
167
+ const auto returns = torch::jit::pop(*stack, num_returns);
168
+ for (const auto& ret : returns) {
169
+ if (ret.isTensor()) {
170
+ torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
171
+ } else {
172
+ TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
173
+ }
174
+ }
175
+ }
176
+
177
+ inline void handle_pointwise_ops(std::vector<std::pair<Tensor, optional<int64_t>>> &tensor_inputs) {
178
+ int64_t out_logical_rank = 0;
179
+ for (auto& tensor_input : tensor_inputs) {
180
+ int64_t cur_logical_rank = rankWithoutBatchDim(tensor_input.first, tensor_input.second);
181
+ out_logical_rank = std::max(out_logical_rank, cur_logical_rank);
182
+ }
183
+ for (auto& tensor_input: tensor_inputs) {
184
+ tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
185
+ tensor_input.first = maybePadToLogicalRank(tensor_input.first, tensor_input.second, out_logical_rank);
186
+ }
187
+ }
188
+
189
+ #define POINTWISE_BOXED(op) \
190
+ m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
191
+
192
+ #define POINTWISE_BOXED2(op, overload) \
193
+ m.impl(#op "." #overload, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_pointwise_ops), &handle_pointwise_ops>>());
194
+
195
+ inline void handle_variadic_bdims(std::vector<std::pair<Tensor, optional<int64_t>>> &tensor_inputs) {
196
+ for (auto & tensor_input : tensor_inputs) {
197
+ tensor_input.first = moveBatchDimToFront(tensor_input.first, tensor_input.second);
198
+ }
199
+ }
200
+
201
+ #define VARIADIC_BDIMS_BOXED(op) \
202
+ m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_tensor_inputs_batch_rule<decltype(&handle_variadic_bdims), &handle_variadic_bdims>>());
203
+
204
+ using UnpackedBatchedTensor = std::tuple<Tensor,optional<int64_t>>;
205
+
206
+ inline void find_and_unpack_tensors(
207
+ const torch::jit::Stack* stack,
208
+ int64_t num_args,
209
+ int64_t cur_level,
210
+ SmallVector<UnpackedBatchedTensor, 5>* tensors,
211
+ SmallVector<int64_t, 5>* tensors_pos,
212
+ int64_t* batch_size) {
213
+
214
+ int64_t computed_batch_size = -1;
215
+ int64_t args_begin = static_cast<int64_t>(stack->size()) - num_args;
216
+
217
+ for (const auto idx : c10::irange(0, num_args)) {
218
+ const auto& ivalue = (*stack)[args_begin + idx];
219
+ if (!ivalue.isTensor()) {
220
+ continue;
221
+ }
222
+ auto unpacked = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
223
+ const auto& tensor_value = std::get<0>(unpacked);
224
+ const auto tensor_bdim = std::get<1>(unpacked);
225
+ if (tensor_bdim.has_value()) {
226
+ auto candidate_batch_size = tensor_value.size(*tensor_bdim);
227
+ if (computed_batch_size == -1) {
228
+ computed_batch_size = candidate_batch_size;
229
+ }
230
+ TORCH_INTERNAL_ASSERT(candidate_batch_size == computed_batch_size);
231
+ }
232
+
233
+ tensors->push_back(std::move(unpacked));
234
+ tensors_pos->push_back(idx);
235
+ }
236
+ TORCH_INTERNAL_ASSERT(computed_batch_size > -1);
237
+ *batch_size = computed_batch_size;
238
+ }
239
+
240
+ inline void boxed_existing_bdim_all_batch_rule(
241
+ const c10::OperatorHandle& op, torch::jit::Stack* stack) {
242
+ const auto& schema = op.schema();
243
+ const auto num_returns = schema.returns().size();
244
+ const auto num_arguments = static_cast<int64_t>(schema.arguments().size());
245
+
246
+ c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
247
+ auto maybe_layer = maybeCurrentDynamicLayer();
248
+ vmap_check_escaped(maybe_layer, "boxed_existing_bdim_all_batch_rule");
249
+ int64_t cur_level = maybe_layer->layerId();
250
+
251
+ const auto arguments = torch::jit::last(stack, num_arguments);
252
+ if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
253
+ op.callBoxed(stack);
254
+ return;
255
+ }
256
+
257
+ int64_t args_begin = static_cast<int64_t>(stack->size()) - num_arguments;
258
+ SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
259
+ SmallVector<int64_t, 5> tensor_pos;
260
+ int64_t batch_size = 0;
261
+
262
+ find_and_unpack_tensors(
263
+ stack, num_arguments, cur_level,
264
+ &tensor_inputs, &tensor_pos, &batch_size);
265
+
266
+ // for each tensor, ensure it has a bdim and reshape it.
267
+ for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
268
+ const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
269
+ auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
270
+ auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
271
+ if (!bdim.has_value()) {
272
+ bdim = 0;
273
+ }
274
+ (*stack)[args_begin + tensor_pos[tensor_idx]] = reshape_dim_into(*bdim, 0, value_);
275
+ }
276
+
277
+ op.callBoxed(stack);
278
+
279
+ for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
280
+ const auto& ret = (*stack)[idx];
281
+ TORCH_INTERNAL_ASSERT(ret.isTensor(),
282
+ "This boxed batching rule does not currently support ops that return non-tensor values");
283
+ (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
284
+ }
285
+ }
286
+
287
+ // Use when all tensor arguments accept one (normal) batch dim.
288
+ // This batching rule expands the batch dim on all Tensors, reshapes it into
289
+ // dim 0, calls the op, and then reshapes the batch dim out of dim 0.
290
+ // This is not the most efficient thing; if there are alternatives, please try
291
+ // to use them. Use this only as a last resort.
292
+ #define EXISTING_BDIM_ALL_BOXED(op) \
293
+ m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_existing_bdim_all_batch_rule>());
294
+
295
+ template <int64_t feature_rank, int64_t contig_tensor_index=-1>
296
+ inline void boxed_all_tensors_have_optional_bdim(
297
+ const c10::OperatorHandle& op, torch::jit::Stack* stack) {
298
+ const auto& schema = op.schema();
299
+ const auto num_returns = schema.returns().size();
300
+ const auto num_arguments = schema.arguments().size();
301
+
302
+ c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
303
+ auto maybe_layer = maybeCurrentDynamicLayer();
304
+ vmap_check_escaped(maybe_layer, "boxed_all_tensors_have_optional_bdim");
305
+ int64_t cur_level = maybe_layer->layerId();
306
+
307
+ const auto arguments = torch::jit::last(stack, num_arguments);
308
+ if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
309
+ op.callBoxed(stack);
310
+ return;
311
+ }
312
+
313
+ int64_t args_begin = static_cast<int64_t>(stack->size() - num_arguments);
314
+ SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
315
+ SmallVector<int64_t, 5> tensor_pos;
316
+ int64_t batch_size = 0;
317
+
318
+ find_and_unpack_tensors(
319
+ stack, static_cast<int64_t>(num_arguments), cur_level,
320
+ &tensor_inputs, &tensor_pos, &batch_size);
321
+
322
+ optional<bool> is_no_batch_dim_case;
323
+
324
+ for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
325
+ const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
326
+ auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
327
+ const auto logical_rank = rankWithoutBatchDim(value, bdim);
328
+
329
+ if (!is_no_batch_dim_case.has_value()) {
330
+ is_no_batch_dim_case = (logical_rank == feature_rank);
331
+ }
332
+ auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
333
+ if (!bdim.has_value()) {
334
+ bdim = 0;
335
+ }
336
+ if (*is_no_batch_dim_case) {
337
+ TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
338
+ value_ = moveBatchDimToFront(value_, bdim);
339
+ if (tensor_idx == contig_tensor_index) {
340
+ value_ = value_.contiguous();
341
+ }
342
+ (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
343
+ continue;
344
+ }
345
+ TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
346
+ value_ = reshape_dim_into(*bdim, 0, value_);
347
+ if (tensor_idx == contig_tensor_index) {
348
+ value_ = value_.contiguous();
349
+ }
350
+ (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
351
+ }
352
+
353
+ op.callBoxed(stack);
354
+
355
+ for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
356
+ const auto& ret = (*stack)[idx];
357
+ TORCH_INTERNAL_ASSERT(ret.isTensor(),
358
+ "This boxed batching rule does not currently support ops that return non-tensor values");
359
+ if (*is_no_batch_dim_case) {
360
+ (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
361
+ } else {
362
+ (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
363
+ }
364
+ }
365
+ }
366
+
367
+ // Useful for many NN operators.
368
+ // The operator must satisfy the following:
369
+ // - All arguments must accept an optional batch dim.
370
+ // - All arguments must be the same rank
371
+ #define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED(feature_rank, op) \
372
+ m.impl(#op, torch::CppFunction::makeFromBoxedFunction<boxed_all_tensors_have_optional_bdim<feature_rank>>());
373
+
374
+ #define ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(feature_rank, op, contig_tensor_index) \
375
+ m.impl(#op, \
376
+ torch::CppFunction::makeFromBoxedFunction<\
377
+ boxed_all_tensors_have_optional_bdim<\
378
+ feature_rank, \
379
+ contig_tensor_index>\
380
+ >());
381
+
382
+ template <typename A, A a, typename C>
383
+ struct ExistingBdimBatchRuleHelper;
384
+
385
+ template <typename F, F Func, typename A, typename... T>
386
+ struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
387
+ static std::tuple<Tensor,optional<int64_t>> apply(
388
+ const Tensor& self,
389
+ optional<int64_t> self_bdim,
390
+ T... extra_args) {
391
+ auto self_ = reshape_dim_into(*self_bdim, 0, self);
392
+ auto out = Func(self_, std::forward<T>(extra_args)...);
393
+ return std::make_tuple(reshape_dim_outof_symint(0, self.sym_sizes()[*self_bdim], out), 0);
394
+ }
395
+ };
396
+
397
+ // USAGE: EXISTING_BDIM_BATCH_RULE(at::cholesky_inverse)
398
+ // INCORRECT USAGE: EXISTING_BDIM_BATCH_RULE(&at::cholesky_inverse)
399
+ // It is important that this macro is not passed a function pointer!!
400
+ #define EXISTING_BDIM_BATCH_RULE(fn) SINGLE_ARG(\
401
+ ExistingBdimBatchRuleHelper<\
402
+ decltype(&fn),\
403
+ &fn,\
404
+ c10::guts::function_traits<decltype(fn)>::parameter_types>::apply)
405
+
406
+
407
+ #define EXISTING_BDIM(op) \
408
+ VMAP_SUPPORT(op, EXISTING_BDIM_BATCH_RULE(ATEN_FN(op)));
409
+
410
+ #define EXISTING_BDIM2(op, overload) \
411
+ VMAP_SUPPORT2(op, overload, EXISTING_BDIM_BATCH_RULE(ATEN_FN2(op, overload)));
412
+
413
+ #define INVOKE(object,ptrToMember) ((object).*(ptrToMember))
414
+
415
+
416
+ template <typename F, F Method, typename... ExtraArgs>
417
+ Tensor& unary_inplace_batch_rule(Tensor& self, optional<int64_t>, ExtraArgs... extra_args) {
418
+ INVOKE(self, Method)(std::forward<ExtraArgs>(extra_args)...);
419
+ return self;
420
+ }
421
+
422
+ inline int64_t get_bdim_size4(
423
+ const Tensor& a_value, optional<int64_t> a_bdim,
424
+ const Tensor& b_value, optional<int64_t> b_bdim,
425
+ const Tensor& c_value, optional<int64_t> c_bdim,
426
+ const Tensor& d_value, optional<int64_t> d_bdim) {
427
+ if (a_bdim)
428
+ return a_value.size(*a_bdim);
429
+ if (b_bdim)
430
+ return b_value.size(*b_bdim);
431
+ if (c_bdim)
432
+ return c_value.size(*c_bdim);
433
+ if (d_bdim)
434
+ return d_value.size(*d_bdim);
435
+ TORCH_INTERNAL_ASSERT(false);
436
+ }
437
+
438
+ inline int64_t get_bdim_size3(
439
+ const Tensor& a_value, optional<int64_t> a_bdim,
440
+ const Tensor& b_value, optional<int64_t> b_bdim,
441
+ const Tensor& c_value, optional<int64_t> c_bdim) {
442
+ if (a_bdim)
443
+ return a_value.size(*a_bdim);
444
+ if (b_bdim)
445
+ return b_value.size(*b_bdim);
446
+ if (c_bdim)
447
+ return c_value.size(*c_bdim);
448
+ TORCH_INTERNAL_ASSERT(false);
449
+ }
450
+
451
+ inline int64_t get_bdim_size2(
452
+ const Tensor& a_value, optional<int64_t> a_bdim,
453
+ const Tensor& b_value, optional<int64_t> b_bdim) {
454
+ if (a_bdim)
455
+ return a_value.size(*a_bdim);
456
+ if (b_bdim)
457
+ return b_value.size(*b_bdim);
458
+ TORCH_INTERNAL_ASSERT(false);
459
+ }
460
+
461
+ // [start, start + 1, ..., stop - 1]
462
+ inline VmapDimVector range(int64_t start, int64_t stop) {
463
+ TORCH_INTERNAL_ASSERT(stop >= start);
464
+ VmapDimVector dims;
465
+ dims.reserve(stop - start);
466
+ for (int64_t i = start; i < stop; i++) {
467
+ dims.emplace_back(i);
468
+ }
469
+ return dims;
470
+ }
471
+ std::tuple<Tensor, Tensor> _binary_pointwise_helper(
472
+ const Tensor& tensor, optional<int64_t> tensor_batch_dim, const Tensor& other, optional<int64_t> other_batch_dim,
473
+ bool do_type_promotion=true);
474
+
475
+ } // namespace at::functorch
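The macros above generate most batching rules, but a hand-written rule follows the same (Tensor, optional bdim) -> (Tensor, optional bdim) contract. A sketch of a hypothetical unary rule, using only helpers declared in this header; registration is shown in a comment because it normally lives inside a TORCH_LIBRARY_IMPL block:

    #include <ATen/functorch/BatchRulesHelper.h>

    namespace at::functorch {

    // Hypothetical hand-written rule: move the batch dim to the front,
    // run the per-example computation, and report the result's bdim as 0.
    static std::tuple<Tensor, optional<int64_t>> my_unary_batch_rule(
        const Tensor& self, optional<int64_t> self_bdim) {
      auto self_ = moveBatchDimToFront(self, self_bdim);
      return std::make_tuple(self_.sin(), 0);
    }

    // TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
    //   VMAP_SUPPORT(sin, my_unary_batch_rule);
    // }

    } // namespace at::functorch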
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchedFallback.h ADDED
@@ -0,0 +1,81 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+ #include <ATen/ATen.h>
9
+ #include <ATen/core/op_registration/op_registration.h>
10
+ #include <torch/library.h>
11
+
12
+ namespace at::functorch {
13
+
14
+ // This file contains code for the vmap fallback (also known as the
15
+ // BatchedTensor fallback or the Batched fallback). This code runs
16
+ // when an operation doesn't have a batching rule implemented.
17
+
18
+ // If an operator doesn't have a batching rule implemented then we fallback
19
+ // to this implementation. The fallback doesn't work on out= variants or
20
+ // view operations; that is, it works for out-of-place operations and
21
+ // in-place non-view operations.
22
+ //
23
+ // For out-of-place operations, the fallback effectively takes all of the
24
+ // BatchedTensors in `stack`, slices them, and runs `op` on all of the
25
+ // corresponding slices to produce slices of the outputs. The output slices
26
+ // then get `torch.stack`ed to create the
27
+ // final returns.
28
+ //
29
+ // The performance of the fallback is not very good because it introduces an
30
+ // extra copy from stacking the sliced outputs. Because of this, we prefer to
31
+ // write batching rules for operators whenever possible.
32
+ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
33
+ void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
34
+
35
+ void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
36
+
37
+ // The vmap fallback emits a warning by default, but it may be disabled if
38
+ // the user finds it to be too annoying.
39
+ TORCH_API bool isVmapFallbackWarningEnabled();
40
+ TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
41
+
42
+ // Used for testing. The vmap fallback is enabled by default. When it is disabled,
43
+ // it raises an error.
44
+ TORCH_API bool isVmapFallbackEnabled();
45
+ TORCH_API void setVmapFallbackEnabled(bool enabled);
46
+
47
+ template <typename A> A vector_to_result(const std::vector<IValue>& buffer) {
48
+ return buffer[0].to<A>();
49
+ }
50
+ template <typename A, typename B> std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
51
+ return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
52
+ }
53
+ template <typename A, typename B, typename C> std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
54
+ return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<C>());
55
+ }
56
+
57
+ // slow_fallback is a way to call the vmap fallback inside some boxed kernel.
58
+ // There is probably some better way to metaprogram this.
59
+ template <typename Ret>
60
+ Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
61
+ std::vector<IValue> stack(args.begin(), args.end());
62
+ batchedTensorForLoopFallback(op, &stack);
63
+ return vector_to_result<Ret>(stack);
64
+ }
65
+
66
+ template <typename A, typename B>
67
+ std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
68
+ std::vector<IValue> stack(args.begin(), args.end());
69
+ batchedTensorForLoopFallback(op, &stack);
70
+ return vector_to_result<A, B>(stack);
71
+ }
72
+
73
+ template <typename A, typename B, typename C>
74
+ std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
75
+ std::vector<IValue> stack(args.begin(), args.end());
76
+ batchedTensorForLoopFallback(op, &stack);
77
+ return vector_to_result<A, B, C>(stack);
78
+ }
79
+
80
+
81
+ } // namespace at::functorch
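A sketch of deferring to the for-loop fallback from inside a boxed kernel via `slow_fallback`, assuming the caller already has the operator handle and its arguments for an op that returns a single Tensor:

    #include <ATen/functorch/BatchedFallback.h>

    at::Tensor call_via_fallback(const c10::OperatorHandle& op,
                                 c10::ArrayRef<c10::IValue> args) {
      // Runs the op once per batch element and stacks the sliced outputs.
      return at::functorch::slow_fallback<at::Tensor>(op, args);
    }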
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchedTensorImpl.h ADDED
@@ -0,0 +1,169 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+
9
+ #include <bitset>
10
+
11
+ #include <ATen/ArrayRef.h>
12
+ #include <ATen/SmallVector.h>
13
+ #include <ATen/Tensor.h>
14
+
15
+ namespace at::functorch {
16
+
17
+ using Tensor = at::Tensor;
18
+
19
+ // We assume this in a few other places in the codebase,
20
+ // but there isn't a centralized definition.
21
+ constexpr int64_t kVmapMaxTensorDims = 64;
22
+
23
+ // The valid vmap levels range from [0, 64). This effectively means that we
24
+ // support a maximum of 64 nested vmaps.
25
+ constexpr int64_t kVmapNumLevels = 64;
26
+
27
+ // Store this number of elements of BatchDims on the stack. Most people will
28
+ // probably use <= 5 nested vmaps, but adjust this number as necessary.
29
+ constexpr int64_t kBatchDimsStackSize = 5;
30
+
31
+ // A BatchedTensorImpl holds an underlying Tensor and a single batch dim
32
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
33
+ // BatchedTensorImpl.
34
+ //
35
+ // The batch dimensions are treated as being "private"; they are not user-visible.
36
+ // For example, in the following Tensor,
37
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
38
+ // dimension 0 is batch dimension.
39
+ //
40
+ // bt.sizes() returns (3, 5, 7); bt.sum(0) performs a reduction over the (public)
41
+ // dim 0, which is equivalent to dim 1 in the underlying ones(2, 3, 5, 7) tensor.
42
+ struct TORCH_API BatchedTensorImpl : public c10::TensorImpl {
43
+ explicit BatchedTensorImpl(at::DispatchKeySet key_set, Tensor value, int64_t dim, int64_t level);
44
+
45
+ // Returns batch dimension of this tensor
46
+ int64_t bdim() const { return bdim_; }
47
+
48
+ // Returns batch dimension of this tensor
49
+ int64_t level() const { return level_; }
50
+
51
+ // BatchedTensorImpl wraps a Tensor
52
+ const Tensor& value() const { return value_; }
53
+
54
+ // Given a public dimension index, return the dimension index in the underlying
55
+ // value() tensor.
56
+ // For example, if we have
57
+ // bt = BatchedTensorImpl(ones(2, 3, 5, 7), lvl=1, dim=0)
58
+ // bt.actualDim(0) -> 1
59
+ // bt.actualDim(1) -> 2
60
+ // bt.actualDim(2) -> 3
61
+ // bt.actualDim(3) -> Error
62
+ int64_t actualDim(int64_t dim, bool wrap_dim = true) const;
63
+
64
+ IntArrayRef sizes_custom() const override;
65
+ SymIntArrayRef sym_sizes_custom() const override;
66
+ int64_t size_custom(int64_t d) const override;
67
+ c10::SymInt sym_size_custom(int64_t d) const override;
68
+ // We have to override this because we opted into CustomStrides
69
+ IntArrayRef strides_custom() const override;
70
+ SymIntArrayRef sym_strides_custom() const override;
71
+ // Override a bunch of methods inherited from TensorImpl to return error messages.
72
+ bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override;
73
+ void set_size(int64_t dim, int64_t new_size) override;
74
+ void set_stride(int64_t dim, int64_t new_stride) override;
75
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
76
+ const c10::VariableVersion& version_counter,
77
+ bool allow_tensor_metadata_change) const override;
78
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
79
+ c10::VariableVersion&& version_counter,
80
+ bool allow_tensor_metadata_change) const override;
81
+ void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
82
+ #ifdef DEBUG
83
+ bool has_storage() const override;
84
+ #endif
85
+
86
+ void refreshTensorMetadata();
87
+
88
+ // Used in torchdim. torchdim uses non-lexical BatchedTensor; the way it
89
+ // accomplishes this is a hack where it is able to modify the levels of
90
+ // BatchedTensor to match the level of the current vmap transform.
91
+ void _unsafe_set_level(int64_t level) {
92
+ level_ = level;
93
+ }
94
+
95
+ // Used in batching rule for in-place view operations that can change
96
+ // the index of the bdim (think squeeze_, unsqueeze_)
97
+ void unsafe_set_bdim(int64_t bdim) {
98
+ // NB: you MUST call refreshTensorMetadata after doing this.
99
+ bdim_ = bdim;
100
+ }
101
+ private:
102
+ // see NOTE: [BatchedTensorImpl levels invariant]
103
+ void checkInvariants() const;
104
+ const char* tensorimpl_type_name() const override;
105
+
106
+ Tensor value_;
107
+
108
+ int64_t level_;
109
+ int64_t bdim_;
110
+ };
111
+
112
+ // NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a
113
+ // BatchedTensorImpl.
114
+ inline bool isBatchedTensor(const Tensor& tensor) {
115
+ return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::FuncTorchBatched) ||
116
+ tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::BatchedNestedTensor);
117
+ }
118
+
119
+ // It is unsafe to call this on a Tensor that is not backed by a
120
+ // BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible.
121
+ inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) {
122
+ return static_cast<BatchedTensorImpl*>(tensor.unsafeGetTensorImpl());
123
+ }
124
+
125
+ inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) {
126
+ if (!isBatchedTensor(tensor)) {
127
+ return nullptr;
128
+ }
129
+ return unsafeGetBatchedImpl(tensor);
130
+ }
131
+
132
+ // Returns a bitset. If bit i is set, then that means dim i is a batchdim.
133
+ inline std::bitset<kVmapMaxTensorDims> createBatchDimBitset(int64_t dim) {
134
+ std::bitset<kVmapMaxTensorDims> is_bdim;
135
+ is_bdim.set(dim);
136
+ return is_bdim;
137
+ }
138
+
139
+ // Creates a bitset for the given level
140
+ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
141
+ std::bitset<kVmapNumLevels> result;
142
+ result.set(level);
143
+ return result;
144
+ }
145
+
146
+ // Use this to construct a BatchedTensor from a regular Tensor
147
+ TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);
148
+
149
+ // Adds a batch dim to `tensor`, returning a BatchedTensor
150
+ TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
151
+
152
+ // Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
153
+ // any wrapper Tensor subclasses). This is because there are methods on Tensor
154
+ // that skip dispatch and check for the presence of a dispatch key (e.g. is_cpu()).
155
+ // TODO: should probably contain more (or all?) backend keys
156
+ constexpr DispatchKeySet kKeysToPropagateToWrapper({
157
+ DispatchKey::Negative,
158
+ DispatchKey::Conjugate,
159
+ DispatchKey::XLA,
160
+ DispatchKey::CUDA,
161
+ DispatchKey::CPU,
162
+ });
163
+
164
+ inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {
165
+ auto key_set = tensor.unsafeGetTensorImpl()->key_set();
166
+ return key_set & kKeysToPropagateToWrapper;
167
+ }
168
+
169
+ } // namespace at::functorch
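A sketch of the public helpers at the bottom of this header, illustrating the logical-vs-physical bookkeeping documented above (the demo function is illustrative):

    #include <ATen/ATen.h>
    #include <ATen/functorch/BatchedTensorImpl.h>

    void batched_tensor_demo() {
      using namespace at::functorch;
      at::Tensor t = at::ones({2, 3, 5});
      // Treat dim 0 as a private batch dim belonging to vmap level 1;
      // the logical (user-visible) sizes become (3, 5).
      at::Tensor bt = makeBatched(t, /*dim=*/0, /*level=*/1);

      if (BatchedTensorImpl* impl = maybeGetBatchedImpl(bt)) {
        TORCH_INTERNAL_ASSERT(impl->bdim() == 0);
        TORCH_INTERNAL_ASSERT(impl->actualDim(0) == 1);  // logical 0 -> physical 1
      }
    }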
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/BatchingMetaprogramming.h ADDED
@@ -0,0 +1,126 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+ #include <ATen/Tensor.h>
9
+ #include <ATen/VmapGeneratedPlumbing.h>
10
+
11
+ // This file contains template metaprogramming things that are used for our
12
+ // batching rules.
13
+ //
14
+ // See NOTE: [vmap plumbing] for more details on why this is necessary.
15
+ // The plumbing has a bunch of metaprogramming hacks for determining the signature
16
+ // of a batching rule from the signature of the operator, many of which use the
17
+ // helper functions in this file.
18
+
19
+ namespace at::functorch {
20
+
21
+ // Metaprogramming things
22
+ template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
23
+ template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
24
+ template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
25
+ template <typename T> class debug_t;
26
+
27
+ // tail operation
28
+ template<class TypeList>
29
+ struct tail final {
30
+ static_assert(c10::guts::false_t<TypeList>::value,
31
+ "In typelist::tail<T>, the T argument must be typelist<...>.");
32
+ };
33
+ template<class Head, class... Tail>
34
+ struct tail<typelist<Head, Tail...>> final {
35
+ using type = typelist<Tail...>;
36
+ };
37
+ template<class TypeList> using tail_t = typename tail<TypeList>::type;
38
+
39
+ template <class First, class Second, class Next, class Tail>
40
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
41
+ using type = Next;
42
+ };
43
+ template <class Next, class Tail>
44
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, optional<int64_t>, Next, Tail> {
45
+ using type = Tail;
46
+ };
47
+ template <class Next, class Tail>
48
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, optional<int64_t>, Next, Tail> {
49
+ using type = Tail;
50
+ };
51
+ template <class Next, class Tail>
52
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, optional<int64_t>, Next, Tail> {
53
+ using type = Tail;
54
+ };
55
+ template <class Next, class Tail>
56
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<optional<Tensor>, optional<int64_t>, Next, Tail> {
57
+ using type = Tail;
58
+ };
59
+ template <class Next, class Tail>
60
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const optional<Tensor>&, optional<int64_t>, Next, Tail> {
61
+ using type = Tail;
62
+ };
63
+ template <class Next, class Tail>
64
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<optional<Tensor>&, optional<int64_t>, Next, Tail> {
65
+ using type = Tail;
66
+ };
67
+ template <class Next, class Tail>
68
+ struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::vector<Tensor>, optional<int64_t>, Next, Tail> {
69
+ using type = Tail;
70
+ };
71
+ template <class TypeList> struct RemoveBatchDimAfterTensor {
72
+ using first = head_t<TypeList>;
73
+ using next = tail_t<TypeList>;
74
+ using second = head_t<next>;
75
+ using tail = tail_t<next>;
76
+
77
+ using type = concat_t<
78
+ typelist<first>,
79
+ typename RemoveBatchDimAfterTensor<
80
+ typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
81
+ >::type
82
+ >;
83
+ };
84
+ template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
85
+ using type = typelist<Type>;
86
+ };
87
+ template <> struct RemoveBatchDimAfterTensor<typelist<>> {
88
+ using type = typelist<>;
89
+ };
90
+ template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;
91
+
92
+ template <typename T> struct UnpackSingleItemTuple {
93
+ using type = T;
94
+ };
95
+ template <typename T> struct UnpackSingleItemTuple<std::tuple<T>> {
96
+ using type = T;
97
+ };
98
+ template <typename T> using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;
99
+
100
+ template <typename Return, typename TupleArgs> struct BuildFunctionHelper;
101
+ template <typename Return, typename... Args> struct BuildFunctionHelper<Return, std::tuple<Args...>> {
102
+ using type = Return(Args...);
103
+ };
104
+ template <typename Return, typename TL>
105
+ struct BuildFunction {
106
+ using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
107
+ };
108
+ template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;
109
+
110
+
111
+ template <typename batch_rule_t> struct ToOperatorType {
112
+ using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
113
+ using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;
114
+
115
+ using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
116
+ using operator_return_type =
117
+ unpack_single_item_tuple_t<
118
+ c10::guts::typelist::to_tuple_t<
119
+ remove_batch_dim_after_tensor_t<
120
+ c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;
121
+
122
+ using type = build_function_t<operator_return_type, operator_parameter_types>;
123
+ };
124
+ template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;
125
+
126
+ } // namespace at::functorch
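A sketch of what `to_operator_t` computes for the simplest batch-rule signature; the alias names are hypothetical and the stated result follows from the comments above:

    #include <ATen/functorch/BatchingMetaprogramming.h>

    namespace at::functorch {

    // A unary batch rule: (Tensor, bdim) -> (Tensor, bdim).
    using my_rule_t =
        std::tuple<Tensor, optional<int64_t>>(const Tensor&, optional<int64_t>);

    // Stripping the batch-dim slot after each Tensor on both sides should
    // recover the plain operator signature, i.e. Tensor(const Tensor&).
    using my_op_t = to_operator_t<my_rule_t>;

    } // namespace at::functorch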
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/DynamicLayer.h ADDED
@@ -0,0 +1,124 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+ #include <ATen/functorch/Macros.h>
9
+ #include <c10/core/DispatchKey.h>
10
+ #include <ATen/core/function_schema.h>
11
+ #include <c10/util/Optional.h>
12
+ #include <c10/core/impl/LocalDispatchKeySet.h>
13
+ #include <ATen/functorch/Interpreter.h>
14
+ #include <ATen/functorch/VmapInterpreter.h>
15
+ #include <ATen/functorch/ADInterpreters.h>
16
+ #include <ATen/functorch/FunctionalizeInterpreter.h>
17
+
18
+ // Forward declared
19
+ namespace c10 { struct AutogradMetaInterface; }
20
+
21
+ namespace at::functorch {
22
+
23
+ // This file contains the implementation of functorch's interpreter stack.
24
+ // See NOTE: [functorch interpreter stack] first before reading on.
25
+ //
26
+ // NB: the functorch interpreter stack is also referred to as:
27
+ // - the "dynamic layer stack" -- an older name for "interpreter" was
28
+ // "dynamic layer".
29
+ // - the "functorch mode stack". You can think of each functorch transform as a
30
+ // "mode" (in the same sense as torch_dispatch mode or torch_function mode),
31
+ // and functorch being an implementation of a "mode stack" where the modes
32
+ // may be arbitrarily composed.
33
+
34
+ // DynamicLayer is basically the same thing as an Interpreter.
35
+ // It represents a functorch transform and it holds an Interpreter,
36
+ // which contains metadata related to the transform and instructions on
37
+ // how to perform the transform.
38
+ //
39
+ // TODO: we can excise DynamicLayer in favor of Interpreter,
40
+ // But I am going to leave it for now as a compatibility shim to avoid
41
+ // needing to refactor a lot of callsites...
42
+ struct TORCH_API DynamicLayer {
43
+ explicit DynamicLayer(
44
+ TransformType transform_type,
45
+ int64_t layerId,
46
+ optional<c10::SymInt> batchSize = nullopt,
47
+ optional<RandomnessType> randomness = nullopt,
48
+ optional<bool> prev_grad_mode = nullopt,
49
+ optional<bool> pre_fwd_grad_mode = nullopt,
50
+ optional<bool> functionalize_add_back_views = nullopt);
51
+
52
+ TransformType key() const;
53
+ int64_t layerId() const;
54
+
55
+ const Interpreter& interpreter() const { return interpreter_; }
56
+ Interpreter& interpreter() { return interpreter_; }
57
+
58
+ // Only valid for vmap
59
+ c10::SymInt batchSize() const;
60
+ RandomnessType randomness() const;
61
+
62
+ private:
63
+ Interpreter interpreter_;
64
+ };
65
+
66
+ TORCH_API int64_t initAndPushDynamicLayer(
67
+ TransformType transform_type,
68
+ optional<c10::SymInt> batch_size = nullopt,
69
+ optional<RandomnessType> randomness = nullopt,
70
+ optional<bool> prev_grad_mode = nullopt,
71
+ optional<bool> prev_fwd_grad_mode = nullopt,
72
+ optional<bool> functionalize_add_back_views = nullopt);
73
+ TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
74
+ TORCH_API std::optional<DynamicLayer> maybeCurrentDynamicLayer();
75
+ TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
76
+ TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
77
+ TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
78
+
79
+ // NOTE: [Life handles and lexically scoped transforms]
80
+ // functorch transforms are lexically scoped.
81
+ // Given a level, we store a "life handle" that is a boolean that tells us if the
82
+ // transform with that level is active or not.
83
+ //
84
+ // functorch's TensorWrapper (for grad transforms) stores a life handle.
85
+ // If a TensorWrapper escapes from the scope of the transform, then somehow
86
+ // it must know it escaped; it can tell by querying the life handle.
87
+ TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
88
+
89
+ // Returns if an operator is in-place. An operator is inplace if:
90
+ // 1. The first argument is a Tensor and it is being written to
91
+ // 2. The first argument is being returned
92
+ // 3. No other arguments are aliased
93
+ // Here is an example of an in-place operator:
94
+ // add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
95
+ TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
96
+
97
+ // Given the indices of unwrapped inputs and the schema, this returns the indices of any outputs that should remain unwrapped
98
+ TORCH_API std::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
99
+
100
+ TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
101
+ TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
102
+
103
+ // Pretty printers
104
+ TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
105
+ TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
106
+
107
+ // While a functorch transform is active, torch.autograd.function._SingleLevelFunction
108
+ // is disabled by default. The following two APIs are APIs for enabling
109
+ // it. These are not user-facing APIs. We can delete this in the future, but
110
+ // it is useful for debugging when something goes wrong with the
111
+ // autograd.Function <> functorch interaction, which uses _SingleLevelFunction,
112
+ // because it leads to loud errors if something is incorrect.
113
+ TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
114
+ TORCH_API bool getSingleLevelAutogradFunctionAllowed();
115
+
116
+ // While a functorch grad transform is active, Tensor.requires_grad_() gets
117
+ // disabled. These two functions are the mechanism for controlling that.
118
+ TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
119
+ TORCH_API bool getInplaceRequiresGradAllowed();
120
+
121
+ TORCH_API DynamicLayer popDynamicLayer();
122
+ TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);
123
+
124
+ } // namespace at::functorch
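A sketch of the push/pop pairing a transform entry point performs around the user's function, using only functions declared above (level bookkeeping and error handling omitted):

    #include <ATen/functorch/DynamicLayer.h>

    void run_under_vmap_layer() {
      using namespace at::functorch;
      // Push a vmap layer with batch size 4; randomness must be specified.
      int64_t level = initAndPushDynamicLayer(
          TransformType::Vmap,
          /*batch_size=*/c10::SymInt(4),
          RandomnessType::Error);
      (void)level;
      // ... run the transformed computation; ops dispatched here see the
      // interpreter for `level` on top of the stack ...
      popDynamicLayerAndDeleteMetadata();
    }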
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/FunctionalizeInterpreter.h ADDED
@@ -0,0 +1,22 @@
1
+ #pragma once
2
+ #include <ATen/functorch/Interpreter.h>
3
+
4
+ namespace at::functorch {
5
+
6
+ // This is the interpreter that handles the functionalize() transform.
7
+ // See NOTE: [functorch interpreter stack] for more details.
8
+
9
+ struct FunctionalizeInterpreterPtr {
10
+ explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
11
+ TransformType key() const { return base_->key(); }
12
+ int64_t level() const { return base_->level(); }
13
+ void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
14
+ void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
15
+ bool functionalizeAddBackViews() const {
16
+ return std::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
17
+ }
18
+ private:
19
+ const Interpreter* base_;
20
+ };
21
+
22
+ } // namespace at::functorch
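The same non-owning wrapper pattern as the AD interpreters; a sketch of reading its single piece of metadata (the free function is illustrative):

    #include <ATen/functorch/FunctionalizeInterpreter.h>

    bool layer_adds_back_views(const at::functorch::Interpreter& interp) {
      using namespace at::functorch;
      if (interp.key() != TransformType::Functionalize) {
        return false;
      }
      return FunctionalizeInterpreterPtr(&interp).functionalizeAddBackViews();
    }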
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/Interpreter.h ADDED
@@ -0,0 +1,209 @@
1
+ #pragma once
2
+
3
+ #include <ATen/functorch/Macros.h>
4
+ #include <ATen/core/dispatch/Dispatcher.h>
5
+ #include <c10/core/impl/LocalDispatchKeySet.h>
6
+ #include <c10/util/Optional.h>
7
+ #include <bitset>
8
+ #include <utility>
9
+ #include <variant>
10
+
11
+ namespace at::functorch {
12
+
13
+ // NOTE: [functorch interpreter stack]
14
+ //
15
+ // functorch's dispatching system uses a stack of interpreters.
16
+ // Historically we've referred to this as the "DynamicLayerStack".
17
+ //
18
+ // An interpreter is something that reads in the code it is passed
19
+ // and then executes it. We have a different interpreter per-transform:
20
+ // the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
21
+ // and executing the batched version of it (the batching rule for aten::mv).
22
+ //
23
+ // Concretely, each interpreter is responsible for two things:
24
+ //
25
+ // 1) process(ophandle, stack)
26
+ // Given an operator handle and a stack of arguments, the interpreter is
27
+ // responsible for figuring out how to execute the operation under the semantics
28
+ // of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call
29
+ // the batching rule.
30
+ //
31
+ // The batching rules are stored as kernels on the FuncTorchBatched key, so the way
32
+ // VmapInterpreter calls the batching rule is roughly: (A) exclude all
33
+ // dispatch keys aside from the Batched key, (B) redispatch so we get to the
34
+ // Batched key.
35
+ //
36
+ // 2) sendToNextInterpreter(ophandle, stack)
37
+ // The VmapInterpreter, when it sees aten::mv, will process it into a call to
38
+ // aten::mm. It then needs to send the call to aten::mm to the next interpreter
39
+ // in the interpreter stack.
40
+ //
41
+ // The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
42
+ // and most Interpreters will implement it this way.
43
+
44
+ enum class RandomnessType {
45
+ Error, // always errors when calling a random function
46
+ Same, // randomness appears the same across batches
47
+ Different, // randomness appears different across batches
48
+ END
49
+ };
50
+
51
+ enum class TransformType {
52
+ Torch, // Unused
53
+ Vmap,
54
+ Grad, // reverse-mode AD, aka vjp
55
+ Jvp, // forward-mode AD
56
+ Functionalize,
57
+ };
58
+
59
+ std::ostream& operator<<(std::ostream& os, const TransformType& t);
60
+
61
+ // NOTE: [Interpreter "subclassing" design]
62
+ //
63
+ // How are various Interpreters for different transforms (vmap, grad, ...)
64
+ // implemented?
65
+ //
66
+ // Accessing interpreters is in the hot-path of functorch so we have a constraint
67
+ // that this code must be as fast as possible.
68
+ //
69
+ // As a result, we stay away from virtual methods and this causes our code
70
+ // to look a little funny.
71
+ //
72
+ // `Interpreter` is the struct for Interpreters. It holds ALL of the
73
+ // relevant information (what type of interpreter it is and the metadata).
74
+ // Metadata for each interpreter is represented as a Union (std::variant)
75
+ // of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
76
+ //
77
+ // Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
78
+ // if you want to access the metadata fields (like batchSize and randomness).
79
+ //
80
+ // Each type of interpreter (e.g. Vmap) has a convenience struct
81
+ // (e.g. VmapInterpreterPtr) associated with it.
82
+ //
83
+ // Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
84
+ // and then one can access methods on VmapInterpreterPtr like so:
85
+ // >>> VmapInterpreterPtr(&interpreter).batchSize()
86
+ //
87
+ // Finally, Interpreter::process switches on the type of the interpreter
88
+ // and calls one of {Transform}Interpreter::processImpl under the hood.
89
+ // Same for Interpreter::sendToNextInterpreter :)
90
+
91
+ struct VmapInterpreterMeta {
92
+ explicit VmapInterpreterMeta(c10::SymInt batchSize, RandomnessType randomness) :
93
+ batchSize_(std::move(batchSize)), randomness_(randomness) {}
94
+ c10::SymInt batchSize_;
95
+ RandomnessType randomness_;
96
+ };
97
+
98
+ struct GradInterpreterMeta {
99
+ explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
100
+ bool prevGradMode_;
101
+ };
102
+
103
+ struct JvpInterpreterMeta {
104
+ explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
105
+ bool prevFwdGradMode_;
106
+ };
107
+
108
+ struct FunctionalizeInterpreterMeta {
109
+ explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
110
+ functionalizeAddBackViews_(functionalizeAddBackViews) {}
111
+ bool functionalizeAddBackViews_;
112
+ };
113
+
114
+ typedef std::variant<
115
+ int64_t,
116
+ GradInterpreterMeta,
117
+ JvpInterpreterMeta,
118
+ VmapInterpreterMeta,
119
+ FunctionalizeInterpreterMeta
120
+ > InterpreterMeta;
121
+
122
+
123
+ struct Interpreter {
124
+ // factory functions
125
+ static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) {
126
+ return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness));
127
+ }
128
+ static Interpreter Grad(int64_t level, bool prevGradMode) {
129
+ return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
130
+ }
131
+ static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
132
+ return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
133
+ }
134
+ static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
135
+ return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
136
+ }
137
+
138
+ // methods
139
+ TransformType key() const { return type_; }
140
+ int64_t level() const { return level_; }
141
+ const InterpreterMeta& meta() const { return meta_; }
142
+
143
+ void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
144
+ void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
145
+
146
+ void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
147
+ TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
148
+ savedLocalDispatchKeySet_ = keyset;
149
+ }
150
+ void clearSavedLocalDispatchKeySet() {
151
+ TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
152
+ savedLocalDispatchKeySet_ = c10::nullopt;
153
+ }
154
+ c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
155
+ TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
156
+ return *savedLocalDispatchKeySet_;
157
+ }
158
+
159
+ // An Interpreter is alive if we are currently inside the ongoing transform
160
+ // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
161
+ // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
162
+ bool is_alive() const {
163
+ return *is_alive_;
164
+ }
165
+ const std::shared_ptr<bool>& is_alive_ptr() const {
166
+ return is_alive_;
167
+ }
168
+ void set_is_alive(bool alive) {
169
+ *is_alive_ = alive;
170
+ }
171
+
172
+ // Please don't use this
173
+ explicit Interpreter() = default;
174
+
175
+ private:
176
+ explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
177
+ type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(std::move(meta)) {}
178
+
179
+ // fields
180
+ TransformType type_{};
181
+ int64_t level_{};
182
+ optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
183
+ std::shared_ptr<bool> is_alive_;
184
+ InterpreterMeta meta_;
185
+ };
186
+
187
+ // Applies the following for-loop:
188
+ // for i in range(begin, end):
189
+ // args[i] = func(args[i])
190
+ void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
191
+ std::function<Tensor(const Tensor&)> func);
192
+
193
+ // Applies the following for-loop:
194
+ // for i in range(begin, end):
195
+ //   if use_flag_relative[i - begin] == 1: <-- treats use_flag_relative as a bitset
196
+ //     args[i] = func(args[i], true)
197
+ //   else: args[i] = func(args[i], false)
198
+ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
199
+ const std::bitset<64> use_flag_relative, const std::function<Tensor(const Tensor&, bool)>& func);
200
+
201
+ std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);
202
+
203
+ DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);
204
+
205
+ void setup_dispatch_key_tls(TransformType key, DispatchKeySet include);
206
+
207
+ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
208
+
209
+ } // namespace at::functorch
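A minimal usage sketch of the Interpreter factory functions and the meta() variant declared above (illustrative only; RandomnessType::Error is assumed to be a valid enumerator from the surrounding functorch headers):

#include <variant>
#include <ATen/functorch/Interpreter.h>

using namespace at::functorch;

void interpreter_sketch() {
  // Build a vmap interpreter at level 1 with a (symbolic) batch size of 4.
  // RandomnessType::Error is assumed to be a valid enumerator here.
  Interpreter interp = Interpreter::Vmap(/*level=*/1, c10::SymInt(4), RandomnessType::Error);

  TORCH_INTERNAL_ASSERT(interp.key() == TransformType::Vmap);
  TORCH_INTERNAL_ASSERT(interp.level() == 1);

  // The per-transform payload lives in the InterpreterMeta variant; select the
  // alternative that matches the transform type before reading it.
  if (std::holds_alternative<VmapInterpreterMeta>(interp.meta())) {
    const auto& meta = std::get<VmapInterpreterMeta>(interp.meta());
    (void)meta;  // batch size and randomness would be read from this struct
  }
}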
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/LegacyVmapTransforms.h ADDED
@@ -0,0 +1,187 @@
1
+ // Copyright (c) Facebook, Inc. and its affiliates.
2
+ // All rights reserved.
3
+ //
4
+ // This source code is licensed under the BSD-style license found in the
5
+ // LICENSE file in the root directory of this source tree.
6
+
7
+ #pragma once
8
+
9
+ #include <ATen/functorch/Macros.h>
10
+ #include <ATen/functorch/BatchedTensorImpl.h>
11
+
12
+ namespace at::functorch {
13
+
14
+ // This file contains the legacy (now-deprecated) batching rule API.
15
+ // Please try to use the new-style batching rule API (see writing_batch_rules.md)
16
+
17
+ // This file contains abstractions used for transforming *logical* vmap arguments
18
+ // into *physical* arguments. (Keep reading for definitions of these terms).
19
+
20
+ // NOTE: [Logical vs physical args]
21
+ // Consider the following vmap.
22
+ // vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
23
+ // This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
24
+ // with batch dims 0 and 2:
25
+ // BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
26
+ //
27
+ // We say the *logical* view of the tensor has size [3] -- tensors inside
28
+ // `func` appear to have size [3].
29
+ // However, the *physical* underlying tensor (the one passed to vmap) has size
30
+ // [2, 3, 4].
31
+ //
32
+ // This notion of logical vs physical also extends to non-tensor arguments.
33
+ // Consider the previous tensor; let's assume the user called
34
+ // `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
35
+ // dimension they are reducing over is dim 0 but the physical dim is dim 1
36
+ // (the first non-batch dimension).
37
+
38
+ // Forward declared; see NOTE: [What is a VmapPhysicalView?]
39
+ struct VmapPhysicalView;
40
+
41
+ // Most PyTorch operators take 4 or fewer inputs.
42
+ constexpr int64_t kVmapTransformStaticInputSize = 4;
43
+ using VmapPhysicalViewVec = SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
44
+
45
+ // PyTorch generally advertises good performance for <= 5 dims.
46
+ // (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
47
+ // dimensions to get 8. Adjust this number as necessary.
48
+ constexpr int64_t kVmapStaticDimVecSize = 8;
49
+ using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
50
+ using VmapSymDimVector = SmallVector<c10::SymInt, kVmapStaticDimVecSize>;
51
+
52
+ // NOTE: [What is a VmapTransform?]
53
+ // A *VmapTransform* converts logical views of tensors to physical views.
54
+ //
55
+ // Batching rules use VmapTransforms to convert logical arguments to
56
+ // physical arguments, then call one or more at:: operators that handle the
57
+ // physical arguments, and then convert the physical result back to a logical
58
+ // argument.
59
+
60
+ // VmapTransform for operators that take tensors with multiple batch dims.
61
+ // Given one or more logical views on Tensors, `logicalToPhysical`
62
+ // permutes all of the batch dims to the front of the tensor, aligns
63
+ // and expands the batch dims to match each other (according to their `level`),
64
+ // and returns a VmapPhysicalView on the tensor(s).
65
+ struct TORCH_API MultiBatchVmapTransform {
66
+ static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
67
+ static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
68
+ };
69
+
70
+ // VmapTransform for operators that broadcast all inputs.
71
+ // Given some logical views on Tensors, `logicalToPhysical`:
72
+ // - permutes all of the batch dims to the front of the tensors
73
+ // - aligns all the batch dims to the collective levels of all of the tensors.
74
+ // If a tensor does not have a batch dim for a vmap level, then it receives
75
+ // a size-one dimension for said level.
76
+ // - aligns the non-batch dims to have the same dimensionality, adding extra
77
+ // size-1 dimensions in between the batch dimensions and the non-batch dimensions
78
+ // so that the batch dimensions are lined up from the right.
79
+ //
80
+ // For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
81
+ // dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap tensors
82
+ // of size (B, 1, 2) and (B, 3, 2).
83
+ //
84
+ // Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
85
+ // VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
86
+ // actually *need* to return a tensor of size (1, 2) for the second tensor
87
+ // because the broadcasting operation takes care of that for us, but we do
88
+ // it anyway to keep things simple.
89
+ struct TORCH_API BroadcastingVmapTransform {
90
+ static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
91
+ };
92
+
93
+ // Forward declared, if you're reading this file head to toe, don't worry about
94
+ // it yet.
95
+ struct VmapPhysicalToLogicalMap;
96
+
97
+ // NOTE: [What is a VmapPhysicalView?]
98
+ // VmapPhysicalView represents a physical view on a Tensor.
99
+ //
100
+ // One can use it to further convert logical dimension indices, logical shapes,
101
+ // and more to their physical variants, or convert a new (physical) tensor into
102
+ // a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
103
+ //
104
+ // VmapPhysicalView stores a physical tensor with all of its batch dimensions at
105
+ // the front and some levels that correspond to said batch dimensions.
106
+ //
107
+ // The levels bitset specifies which vmap levels correspond to the batch
108
+ // dimensions at the front of the tensor. In particular, the number of set bits
109
+ // corresponds to the number of batch dimensions on `tensor` and the rightmost
110
+ // bit of `levels` specifies the maximum number of nested vmaps we are in at
111
+ // this point in time.
112
+ // For example, given:
113
+ // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
114
+ //
115
+ // The rightmost set bit of `levels` is bit 3, indicating that the number of nested vmaps is less
116
+ // than or equal to 3.
117
+ // bitset: 010100
118
+ // ^
119
+ // |
120
+ // levels: 012345
121
+ struct TORCH_API VmapPhysicalView {
122
+ VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
123
+ : levels_(levels), tensor_(std::move(tensor)) {
124
+ // TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
125
+ }
126
+
127
+ Tensor& tensor() { return tensor_; }
128
+ const Tensor& tensor() const { return tensor_; }
129
+
130
+ // Maps logical dim indices to physical dim indices. Also does dim wrapping.
131
+ //
132
+ // For example, given:
133
+ // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
134
+ //
135
+ // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
136
+ // This is because `levels` has two set bits, telling us that the first two dimensions
137
+ // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
138
+ // a physical dim of `n + 2`.
139
+ VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
140
+ int64_t getPhysicalDim(int64_t logical_dim) const;
141
+
142
+ // Returns a VmapPhysicalToLogicalMap object. This can be used for
143
+ // mapping a physical tensor to a new logical tensor (BatchedTensor)
144
+ VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;
145
+
146
+ // Maps a logical shape to a physical shape by pre-pending the batch
147
+ // sizes to the logical shape.
148
+ VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;
149
+ SymDimVector getPhysicalShape(c10::SymIntArrayRef logical_shape) const;
150
+
151
+ int64_t numBatchDims() const;
152
+
153
+ private:
154
+ int64_t numLogicalDims() const;
155
+
156
+ std::bitset<kVmapNumLevels> levels_;
157
+ Tensor tensor_;
158
+ };
159
+
160
+ // Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
161
+ // to a logical one (BatchedTensor). It holds some levels that are used to do the
162
+ // mapping and assumes that the batch dimensions in the physical tensor all
163
+ // occur at the front of the tensor.
164
+ struct TORCH_API VmapPhysicalToLogicalMap {
165
+ VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels): levels_(levels) {}
166
+
167
+ // Maps a physical tensor to a new logical tensor (BatchedTensor).
168
+ // Assumes that all of the "batch dimensions" are at the front
169
+ // of the physical tensor. For example, given:
170
+ // - x = rank-4 Tensor with size 2, 3, 5, 7
171
+ // - levels = (2, 4)
172
+ // Returns:
173
+ // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
174
+ Tensor apply(const Tensor& physical_tensor) const;
175
+
176
+ // Given a vector of physical tensors,
177
+ // 1. maps each tensor to a new logical tensor. Assumes that all of the
178
+ // "batch dimensions" are at the front of the physical tensors.
179
+ // 2. stores the new logical tensors back into the passed-in vector. This is
180
+ // to avoid additional dynamic allocations.
181
+ void applyInplace(std::vector<Tensor>& physical_tensors) const;
182
+
183
+ std::bitset<kVmapNumLevels> levels_;
184
+ };
185
+
186
+
187
+ } // namespace at::functorch
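A short sketch of the logical-to-physical mapping described in the comments above, using only the declarations in this header (the tensor sizes and levels are made up for illustration):

#include <bitset>
#include <vector>
#include <ATen/ATen.h>
#include <ATen/functorch/LegacyVmapTransforms.h>

using namespace at::functorch;

void physical_view_sketch() {
  // A physical tensor whose two leading dims are batch dims for vmap levels 1 and 3.
  std::bitset<kVmapNumLevels> levels;
  levels[1] = true;
  levels[3] = true;
  VmapPhysicalView view(at::ones({2, 3, 4, 5}), levels);

  // Two batch dims sit at the front, so logical dim 0 maps to physical dim 2,
  // mirroring the getPhysicalDims example in the comment above.
  int64_t physical_dim = view.getPhysicalDim(/*logical_dim=*/0);

  // A logical shape becomes a physical shape by pre-pending the batch sizes.
  const std::vector<int64_t> logical_shape = {4, 5};
  VmapDimVector physical_shape = view.getPhysicalShape(logical_shape);  // -> {2, 3, 4, 5}

  (void)physical_dim;
  (void)physical_shape;
}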
pythonProject/.venv/Lib/site-packages/torch/include/ATen/functorch/Macros.h ADDED
@@ -0,0 +1,3 @@
1
+ #pragma once
2
+
3
+ #define SINGLE_ARG(...) __VA_ARGS__
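SINGLE_ARG is the usual preprocessor trick for passing an argument that itself contains commas (such as a template type) through a macro that expects a single parameter: __VA_ARGS__ re-emits the commas after argument parsing. A hypothetical illustration (MY_STATIC_CHECK is invented for this sketch and is not part of the header):

#include <map>
#include <type_traits>
#include <ATen/functorch/Macros.h>

// Hypothetical one-parameter macro, defined only for this illustration.
#define MY_STATIC_CHECK(T) \
  static_assert(std::is_default_constructible<T>::value, "type must be default constructible")

// MY_STATIC_CHECK(std::map<int, int>);           // would not compile: the comma splits the argument in two
MY_STATIC_CHECK(SINGLE_ARG(std::map<int, int>));  // fine: the comma stays inside a single macro argument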