BryanW commited on
Commit
0eaabec
·
verified ·
1 Parent(s): 6c54293

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/__init__.cpython-312.pyc +0 -0
  2. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/impl.cpython-312.pyc +0 -0
  3. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/utils.cpython-312.pyc +0 -0
  4. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/__init__.cpython-312.pyc +0 -0
  5. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/__main__.cpython-312.pyc +0 -0
  6. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/api.cpython-312.pyc +0 -0
  7. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/cd.cpython-312.pyc +0 -0
  8. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/constant.cpython-312.pyc +0 -0
  9. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/legacy.cpython-312.pyc +0 -0
  10. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/md.cpython-312.pyc +0 -0
  11. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/models.cpython-312.pyc +0 -0
  12. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/utils.cpython-312.pyc +0 -0
  13. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/version.cpython-312.pyc +0 -0
  14. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/cli/__init__.py +8 -0
  15. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/cli/__main__.py +321 -0
  16. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/cli/__pycache__/__init__.cpython-312.pyc +0 -0
  17. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/cli/__pycache__/__main__.cpython-312.pyc +0 -0
  18. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_utils/__pycache__/__init__.cpython-312.pyc +0 -0
  19. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_utils/__pycache__/_convertions.cpython-312.pyc +0 -0
  20. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_utils/__pycache__/_inspect.cpython-312.pyc +0 -0
  21. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_utils/__pycache__/_pep440.cpython-312.pyc +0 -0
  22. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/__pycache__/__init__.cpython-312.pyc +0 -0
  23. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/__init__.py +0 -0
  24. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/__pycache__/__init__.cpython-312.pyc +0 -0
  25. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/__pycache__/test_deprecations.cpython-312.pyc +0 -0
  26. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/__pycache__/test_regression.cpython-312.pyc +0 -0
  27. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/test_deprecations.py +20 -0
  28. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/test_linalg.py +2198 -0
  29. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/test_regression.py +145 -0
  30. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/arithmetic.cpython-312.pyc +0 -0
  31. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/arrayterator.cpython-312.pyc +0 -0
  32. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/bitwise_ops.cpython-312.pyc +0 -0
  33. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/einsumfunc.cpython-312.pyc +0 -0
  34. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/lib_utils.cpython-312.pyc +0 -0
  35. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/literal.cpython-312.pyc +0 -0
  36. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/multiarray.cpython-312.pyc +0 -0
  37. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/numeric.cpython-312.pyc +0 -0
  38. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/simple_py3.cpython-312.pyc +0 -0
  39. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/ufuncs.cpython-312.pyc +0 -0
  40. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/warnings_and_errors.cpython-312.pyc +0 -0
  41. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_convolution_double_backward.h +53 -0
  42. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/cudnn_convolution_relu_ops.h +45 -0
  43. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d_ops.h +45 -0
  44. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/silu_meta.h +32 -0
  45. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/slow_conv_transpose2d_compositeexplicitautogradnonfunctional_dispatch.h +29 -0
  46. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/upsample_bilinear2d_compositeexplicitautogradnonfunctional_dispatch.h +29 -0
  47. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/any.h +154 -0
  48. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/extension_set_inl.h +281 -0
  49. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/field_mask.pb.h +320 -0
  50. Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_table_driven.h +344 -0
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (538 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/impl.cpython-312.pyc ADDED
Binary file (10.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/aiohappyeyeballs/__pycache__/utils.cpython-312.pyc ADDED
Binary file (3.75 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.81 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/__main__.cpython-312.pyc ADDED
Binary file (381 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/api.cpython-312.pyc ADDED
Binary file (18.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/cd.cpython-312.pyc ADDED
Binary file (13.3 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/constant.cpython-312.pyc ADDED
Binary file (38.7 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/legacy.cpython-312.pyc ADDED
Binary file (2.85 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/md.cpython-312.pyc ADDED
Binary file (24.4 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/models.cpython-312.pyc ADDED
Binary file (17.2 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/utils.cpython-312.pyc ADDED
Binary file (13.6 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/__pycache__/version.cpython-312.pyc ADDED
Binary file (408 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/cli/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from .__main__ import cli_detect, query_yes_no
4
+
5
+ __all__ = (
6
+ "cli_detect",
7
+ "query_yes_no",
8
+ )
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/cli/__main__.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import sys
5
+ from json import dumps
6
+ from os.path import abspath, basename, dirname, join, realpath
7
+ from platform import python_version
8
+ from unicodedata import unidata_version
9
+
10
+ import charset_normalizer.md as md_module
11
+ from charset_normalizer import from_fp
12
+ from charset_normalizer.models import CliDetectionResult
13
+ from charset_normalizer.version import __version__
14
+
15
+
16
+ def query_yes_no(question: str, default: str = "yes") -> bool:
17
+ """Ask a yes/no question via input() and return their answer.
18
+
19
+ "question" is a string that is presented to the user.
20
+ "default" is the presumed answer if the user just hits <Enter>.
21
+ It must be "yes" (the default), "no" or None (meaning
22
+ an answer is required of the user).
23
+
24
+ The "answer" return value is True for "yes" or False for "no".
25
+
26
+ Credit goes to (c) https://stackoverflow.com/questions/3041986/apt-command-line-interface-like-yes-no-input
27
+ """
28
+ valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
29
+ if default is None:
30
+ prompt = " [y/n] "
31
+ elif default == "yes":
32
+ prompt = " [Y/n] "
33
+ elif default == "no":
34
+ prompt = " [y/N] "
35
+ else:
36
+ raise ValueError("invalid default answer: '%s'" % default)
37
+
38
+ while True:
39
+ sys.stdout.write(question + prompt)
40
+ choice = input().lower()
41
+ if default is not None and choice == "":
42
+ return valid[default]
43
+ elif choice in valid:
44
+ return valid[choice]
45
+ else:
46
+ sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n")
47
+
48
+
49
+ def cli_detect(argv: list[str] | None = None) -> int:
50
+ """
51
+ CLI assistant using ARGV and ArgumentParser
52
+ :param argv:
53
+ :return: 0 if everything is fine, anything else equal trouble
54
+ """
55
+ parser = argparse.ArgumentParser(
56
+ description="The Real First Universal Charset Detector. "
57
+ "Discover originating encoding used on text file. "
58
+ "Normalize text to unicode."
59
+ )
60
+
61
+ parser.add_argument(
62
+ "files", type=argparse.FileType("rb"), nargs="+", help="File(s) to be analysed"
63
+ )
64
+ parser.add_argument(
65
+ "-v",
66
+ "--verbose",
67
+ action="store_true",
68
+ default=False,
69
+ dest="verbose",
70
+ help="Display complementary information about file if any. "
71
+ "Stdout will contain logs about the detection process.",
72
+ )
73
+ parser.add_argument(
74
+ "-a",
75
+ "--with-alternative",
76
+ action="store_true",
77
+ default=False,
78
+ dest="alternatives",
79
+ help="Output complementary possibilities if any. Top-level JSON WILL be a list.",
80
+ )
81
+ parser.add_argument(
82
+ "-n",
83
+ "--normalize",
84
+ action="store_true",
85
+ default=False,
86
+ dest="normalize",
87
+ help="Permit to normalize input file. If not set, program does not write anything.",
88
+ )
89
+ parser.add_argument(
90
+ "-m",
91
+ "--minimal",
92
+ action="store_true",
93
+ default=False,
94
+ dest="minimal",
95
+ help="Only output the charset detected to STDOUT. Disabling JSON output.",
96
+ )
97
+ parser.add_argument(
98
+ "-r",
99
+ "--replace",
100
+ action="store_true",
101
+ default=False,
102
+ dest="replace",
103
+ help="Replace file when trying to normalize it instead of creating a new one.",
104
+ )
105
+ parser.add_argument(
106
+ "-f",
107
+ "--force",
108
+ action="store_true",
109
+ default=False,
110
+ dest="force",
111
+ help="Replace file without asking if you are sure, use this flag with caution.",
112
+ )
113
+ parser.add_argument(
114
+ "-i",
115
+ "--no-preemptive",
116
+ action="store_true",
117
+ default=False,
118
+ dest="no_preemptive",
119
+ help="Disable looking at a charset declaration to hint the detector.",
120
+ )
121
+ parser.add_argument(
122
+ "-t",
123
+ "--threshold",
124
+ action="store",
125
+ default=0.2,
126
+ type=float,
127
+ dest="threshold",
128
+ help="Define a custom maximum amount of noise allowed in decoded content. 0. <= noise <= 1.",
129
+ )
130
+ parser.add_argument(
131
+ "--version",
132
+ action="version",
133
+ version="Charset-Normalizer {} - Python {} - Unicode {} - SpeedUp {}".format(
134
+ __version__,
135
+ python_version(),
136
+ unidata_version,
137
+ "OFF" if md_module.__file__.lower().endswith(".py") else "ON",
138
+ ),
139
+ help="Show version information and exit.",
140
+ )
141
+
142
+ args = parser.parse_args(argv)
143
+
144
+ if args.replace is True and args.normalize is False:
145
+ if args.files:
146
+ for my_file in args.files:
147
+ my_file.close()
148
+ print("Use --replace in addition of --normalize only.", file=sys.stderr)
149
+ return 1
150
+
151
+ if args.force is True and args.replace is False:
152
+ if args.files:
153
+ for my_file in args.files:
154
+ my_file.close()
155
+ print("Use --force in addition of --replace only.", file=sys.stderr)
156
+ return 1
157
+
158
+ if args.threshold < 0.0 or args.threshold > 1.0:
159
+ if args.files:
160
+ for my_file in args.files:
161
+ my_file.close()
162
+ print("--threshold VALUE should be between 0. AND 1.", file=sys.stderr)
163
+ return 1
164
+
165
+ x_ = []
166
+
167
+ for my_file in args.files:
168
+ matches = from_fp(
169
+ my_file,
170
+ threshold=args.threshold,
171
+ explain=args.verbose,
172
+ preemptive_behaviour=args.no_preemptive is False,
173
+ )
174
+
175
+ best_guess = matches.best()
176
+
177
+ if best_guess is None:
178
+ print(
179
+ 'Unable to identify originating encoding for "{}". {}'.format(
180
+ my_file.name,
181
+ (
182
+ "Maybe try increasing maximum amount of chaos."
183
+ if args.threshold < 1.0
184
+ else ""
185
+ ),
186
+ ),
187
+ file=sys.stderr,
188
+ )
189
+ x_.append(
190
+ CliDetectionResult(
191
+ abspath(my_file.name),
192
+ None,
193
+ [],
194
+ [],
195
+ "Unknown",
196
+ [],
197
+ False,
198
+ 1.0,
199
+ 0.0,
200
+ None,
201
+ True,
202
+ )
203
+ )
204
+ else:
205
+ x_.append(
206
+ CliDetectionResult(
207
+ abspath(my_file.name),
208
+ best_guess.encoding,
209
+ best_guess.encoding_aliases,
210
+ [
211
+ cp
212
+ for cp in best_guess.could_be_from_charset
213
+ if cp != best_guess.encoding
214
+ ],
215
+ best_guess.language,
216
+ best_guess.alphabets,
217
+ best_guess.bom,
218
+ best_guess.percent_chaos,
219
+ best_guess.percent_coherence,
220
+ None,
221
+ True,
222
+ )
223
+ )
224
+
225
+ if len(matches) > 1 and args.alternatives:
226
+ for el in matches:
227
+ if el != best_guess:
228
+ x_.append(
229
+ CliDetectionResult(
230
+ abspath(my_file.name),
231
+ el.encoding,
232
+ el.encoding_aliases,
233
+ [
234
+ cp
235
+ for cp in el.could_be_from_charset
236
+ if cp != el.encoding
237
+ ],
238
+ el.language,
239
+ el.alphabets,
240
+ el.bom,
241
+ el.percent_chaos,
242
+ el.percent_coherence,
243
+ None,
244
+ False,
245
+ )
246
+ )
247
+
248
+ if args.normalize is True:
249
+ if best_guess.encoding.startswith("utf") is True:
250
+ print(
251
+ '"{}" file does not need to be normalized, as it already came from unicode.'.format(
252
+ my_file.name
253
+ ),
254
+ file=sys.stderr,
255
+ )
256
+ if my_file.closed is False:
257
+ my_file.close()
258
+ continue
259
+
260
+ dir_path = dirname(realpath(my_file.name))
261
+ file_name = basename(realpath(my_file.name))
262
+
263
+ o_: list[str] = file_name.split(".")
264
+
265
+ if args.replace is False:
266
+ o_.insert(-1, best_guess.encoding)
267
+ if my_file.closed is False:
268
+ my_file.close()
269
+ elif (
270
+ args.force is False
271
+ and query_yes_no(
272
+ 'Are you sure to normalize "{}" by replacing it ?'.format(
273
+ my_file.name
274
+ ),
275
+ "no",
276
+ )
277
+ is False
278
+ ):
279
+ if my_file.closed is False:
280
+ my_file.close()
281
+ continue
282
+
283
+ try:
284
+ x_[0].unicode_path = join(dir_path, ".".join(o_))
285
+
286
+ with open(x_[0].unicode_path, "wb") as fp:
287
+ fp.write(best_guess.output())
288
+ except OSError as e:
289
+ print(str(e), file=sys.stderr)
290
+ if my_file.closed is False:
291
+ my_file.close()
292
+ return 2
293
+
294
+ if my_file.closed is False:
295
+ my_file.close()
296
+
297
+ if args.minimal is False:
298
+ print(
299
+ dumps(
300
+ [el.__dict__ for el in x_] if len(x_) > 1 else x_[0].__dict__,
301
+ ensure_ascii=True,
302
+ indent=4,
303
+ )
304
+ )
305
+ else:
306
+ for my_file in args.files:
307
+ print(
308
+ ", ".join(
309
+ [
310
+ el.encoding or "undefined"
311
+ for el in x_
312
+ if el.path == abspath(my_file.name)
313
+ ]
314
+ )
315
+ )
316
+
317
+ return 0
318
+
319
+
320
+ if __name__ == "__main__":
321
+ cli_detect()
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/cli/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (369 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/charset_normalizer/cli/__pycache__/__main__.cpython-312.pyc ADDED
Binary file (11 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (1.12 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_utils/__pycache__/_convertions.cpython-312.pyc ADDED
Binary file (865 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_utils/__pycache__/_inspect.cpython-312.pyc ADDED
Binary file (9.47 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/_utils/__pycache__/_pep440.cpython-312.pyc ADDED
Binary file (19.1 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.11 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/__init__.py ADDED
File without changes
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (224 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/__pycache__/test_deprecations.cpython-312.pyc ADDED
Binary file (1.27 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/__pycache__/test_regression.cpython-312.pyc ADDED
Binary file (8.7 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/test_deprecations.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test deprecation and future warnings.
2
+
3
+ """
4
+ import numpy as np
5
+ from numpy.testing import assert_warns
6
+
7
+
8
+ def test_qr_mode_full_future_warning():
9
+ """Check mode='full' FutureWarning.
10
+
11
+ In numpy 1.8 the mode options 'full' and 'economic' in linalg.qr were
12
+ deprecated. The release date will probably be sometime in the summer
13
+ of 2013.
14
+
15
+ """
16
+ a = np.eye(2)
17
+ assert_warns(DeprecationWarning, np.linalg.qr, a, mode='full')
18
+ assert_warns(DeprecationWarning, np.linalg.qr, a, mode='f')
19
+ assert_warns(DeprecationWarning, np.linalg.qr, a, mode='economic')
20
+ assert_warns(DeprecationWarning, np.linalg.qr, a, mode='e')
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/test_linalg.py ADDED
@@ -0,0 +1,2198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Test functions for linalg module
2
+
3
+ """
4
+ import os
5
+ import sys
6
+ import itertools
7
+ import traceback
8
+ import textwrap
9
+ import subprocess
10
+ import pytest
11
+
12
+ import numpy as np
13
+ from numpy import array, single, double, csingle, cdouble, dot, identity, matmul
14
+ from numpy.core import swapaxes
15
+ from numpy import multiply, atleast_2d, inf, asarray
16
+ from numpy import linalg
17
+ from numpy.linalg import matrix_power, norm, matrix_rank, multi_dot, LinAlgError
18
+ from numpy.linalg.linalg import _multi_dot_matrix_chain_order
19
+ from numpy.testing import (
20
+ assert_, assert_equal, assert_raises, assert_array_equal,
21
+ assert_almost_equal, assert_allclose, suppress_warnings,
22
+ assert_raises_regex, HAS_LAPACK64, IS_WASM
23
+ )
24
+ try:
25
+ import numpy.linalg.lapack_lite
26
+ except ImportError:
27
+ # May be broken when numpy was built without BLAS/LAPACK present
28
+ # If so, ensure we don't break the whole test suite - the `lapack_lite`
29
+ # submodule should be removed, it's only used in two tests in this file.
30
+ pass
31
+
32
+
33
+ def consistent_subclass(out, in_):
34
+ # For ndarray subclass input, our output should have the same subclass
35
+ # (non-ndarray input gets converted to ndarray).
36
+ return type(out) is (type(in_) if isinstance(in_, np.ndarray)
37
+ else np.ndarray)
38
+
39
+
40
+ old_assert_almost_equal = assert_almost_equal
41
+
42
+
43
+ def assert_almost_equal(a, b, single_decimal=6, double_decimal=12, **kw):
44
+ if asarray(a).dtype.type in (single, csingle):
45
+ decimal = single_decimal
46
+ else:
47
+ decimal = double_decimal
48
+ old_assert_almost_equal(a, b, decimal=decimal, **kw)
49
+
50
+
51
+ def get_real_dtype(dtype):
52
+ return {single: single, double: double,
53
+ csingle: single, cdouble: double}[dtype]
54
+
55
+
56
+ def get_complex_dtype(dtype):
57
+ return {single: csingle, double: cdouble,
58
+ csingle: csingle, cdouble: cdouble}[dtype]
59
+
60
+
61
+ def get_rtol(dtype):
62
+ # Choose a safe rtol
63
+ if dtype in (single, csingle):
64
+ return 1e-5
65
+ else:
66
+ return 1e-11
67
+
68
+
69
+ # used to categorize tests
70
+ all_tags = {
71
+ 'square', 'nonsquare', 'hermitian', # mutually exclusive
72
+ 'generalized', 'size-0', 'strided' # optional additions
73
+ }
74
+
75
+
76
+ class LinalgCase:
77
+ def __init__(self, name, a, b, tags=set()):
78
+ """
79
+ A bundle of arguments to be passed to a test case, with an identifying
80
+ name, the operands a and b, and a set of tags to filter the tests
81
+ """
82
+ assert_(isinstance(name, str))
83
+ self.name = name
84
+ self.a = a
85
+ self.b = b
86
+ self.tags = frozenset(tags) # prevent shared tags
87
+
88
+ def check(self, do):
89
+ """
90
+ Run the function `do` on this test case, expanding arguments
91
+ """
92
+ do(self.a, self.b, tags=self.tags)
93
+
94
+ def __repr__(self):
95
+ return f'<LinalgCase: {self.name}>'
96
+
97
+
98
+ def apply_tag(tag, cases):
99
+ """
100
+ Add the given tag (a string) to each of the cases (a list of LinalgCase
101
+ objects)
102
+ """
103
+ assert tag in all_tags, "Invalid tag"
104
+ for case in cases:
105
+ case.tags = case.tags | {tag}
106
+ return cases
107
+
108
+
109
+ #
110
+ # Base test cases
111
+ #
112
+
113
+ np.random.seed(1234)
114
+
115
+ CASES = []
116
+
117
+ # square test cases
118
+ CASES += apply_tag('square', [
119
+ LinalgCase("single",
120
+ array([[1., 2.], [3., 4.]], dtype=single),
121
+ array([2., 1.], dtype=single)),
122
+ LinalgCase("double",
123
+ array([[1., 2.], [3., 4.]], dtype=double),
124
+ array([2., 1.], dtype=double)),
125
+ LinalgCase("double_2",
126
+ array([[1., 2.], [3., 4.]], dtype=double),
127
+ array([[2., 1., 4.], [3., 4., 6.]], dtype=double)),
128
+ LinalgCase("csingle",
129
+ array([[1. + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=csingle),
130
+ array([2. + 1j, 1. + 2j], dtype=csingle)),
131
+ LinalgCase("cdouble",
132
+ array([[1. + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=cdouble),
133
+ array([2. + 1j, 1. + 2j], dtype=cdouble)),
134
+ LinalgCase("cdouble_2",
135
+ array([[1. + 2j, 2 + 3j], [3 + 4j, 4 + 5j]], dtype=cdouble),
136
+ array([[2. + 1j, 1. + 2j, 1 + 3j], [1 - 2j, 1 - 3j, 1 - 6j]], dtype=cdouble)),
137
+ LinalgCase("0x0",
138
+ np.empty((0, 0), dtype=double),
139
+ np.empty((0,), dtype=double),
140
+ tags={'size-0'}),
141
+ LinalgCase("8x8",
142
+ np.random.rand(8, 8),
143
+ np.random.rand(8)),
144
+ LinalgCase("1x1",
145
+ np.random.rand(1, 1),
146
+ np.random.rand(1)),
147
+ LinalgCase("nonarray",
148
+ [[1, 2], [3, 4]],
149
+ [2, 1]),
150
+ ])
151
+
152
+ # non-square test-cases
153
+ CASES += apply_tag('nonsquare', [
154
+ LinalgCase("single_nsq_1",
155
+ array([[1., 2., 3.], [3., 4., 6.]], dtype=single),
156
+ array([2., 1.], dtype=single)),
157
+ LinalgCase("single_nsq_2",
158
+ array([[1., 2.], [3., 4.], [5., 6.]], dtype=single),
159
+ array([2., 1., 3.], dtype=single)),
160
+ LinalgCase("double_nsq_1",
161
+ array([[1., 2., 3.], [3., 4., 6.]], dtype=double),
162
+ array([2., 1.], dtype=double)),
163
+ LinalgCase("double_nsq_2",
164
+ array([[1., 2.], [3., 4.], [5., 6.]], dtype=double),
165
+ array([2., 1., 3.], dtype=double)),
166
+ LinalgCase("csingle_nsq_1",
167
+ array(
168
+ [[1. + 1j, 2. + 2j, 3. - 3j], [3. - 5j, 4. + 9j, 6. + 2j]], dtype=csingle),
169
+ array([2. + 1j, 1. + 2j], dtype=csingle)),
170
+ LinalgCase("csingle_nsq_2",
171
+ array(
172
+ [[1. + 1j, 2. + 2j], [3. - 3j, 4. - 9j], [5. - 4j, 6. + 8j]], dtype=csingle),
173
+ array([2. + 1j, 1. + 2j, 3. - 3j], dtype=csingle)),
174
+ LinalgCase("cdouble_nsq_1",
175
+ array(
176
+ [[1. + 1j, 2. + 2j, 3. - 3j], [3. - 5j, 4. + 9j, 6. + 2j]], dtype=cdouble),
177
+ array([2. + 1j, 1. + 2j], dtype=cdouble)),
178
+ LinalgCase("cdouble_nsq_2",
179
+ array(
180
+ [[1. + 1j, 2. + 2j], [3. - 3j, 4. - 9j], [5. - 4j, 6. + 8j]], dtype=cdouble),
181
+ array([2. + 1j, 1. + 2j, 3. - 3j], dtype=cdouble)),
182
+ LinalgCase("cdouble_nsq_1_2",
183
+ array(
184
+ [[1. + 1j, 2. + 2j, 3. - 3j], [3. - 5j, 4. + 9j, 6. + 2j]], dtype=cdouble),
185
+ array([[2. + 1j, 1. + 2j], [1 - 1j, 2 - 2j]], dtype=cdouble)),
186
+ LinalgCase("cdouble_nsq_2_2",
187
+ array(
188
+ [[1. + 1j, 2. + 2j], [3. - 3j, 4. - 9j], [5. - 4j, 6. + 8j]], dtype=cdouble),
189
+ array([[2. + 1j, 1. + 2j], [1 - 1j, 2 - 2j], [1 - 1j, 2 - 2j]], dtype=cdouble)),
190
+ LinalgCase("8x11",
191
+ np.random.rand(8, 11),
192
+ np.random.rand(8)),
193
+ LinalgCase("1x5",
194
+ np.random.rand(1, 5),
195
+ np.random.rand(1)),
196
+ LinalgCase("5x1",
197
+ np.random.rand(5, 1),
198
+ np.random.rand(5)),
199
+ LinalgCase("0x4",
200
+ np.random.rand(0, 4),
201
+ np.random.rand(0),
202
+ tags={'size-0'}),
203
+ LinalgCase("4x0",
204
+ np.random.rand(4, 0),
205
+ np.random.rand(4),
206
+ tags={'size-0'}),
207
+ ])
208
+
209
+ # hermitian test-cases
210
+ CASES += apply_tag('hermitian', [
211
+ LinalgCase("hsingle",
212
+ array([[1., 2.], [2., 1.]], dtype=single),
213
+ None),
214
+ LinalgCase("hdouble",
215
+ array([[1., 2.], [2., 1.]], dtype=double),
216
+ None),
217
+ LinalgCase("hcsingle",
218
+ array([[1., 2 + 3j], [2 - 3j, 1]], dtype=csingle),
219
+ None),
220
+ LinalgCase("hcdouble",
221
+ array([[1., 2 + 3j], [2 - 3j, 1]], dtype=cdouble),
222
+ None),
223
+ LinalgCase("hempty",
224
+ np.empty((0, 0), dtype=double),
225
+ None,
226
+ tags={'size-0'}),
227
+ LinalgCase("hnonarray",
228
+ [[1, 2], [2, 1]],
229
+ None),
230
+ LinalgCase("matrix_b_only",
231
+ array([[1., 2.], [2., 1.]]),
232
+ None),
233
+ LinalgCase("hmatrix_1x1",
234
+ np.random.rand(1, 1),
235
+ None),
236
+ ])
237
+
238
+
239
+ #
240
+ # Gufunc test cases
241
+ #
242
+ def _make_generalized_cases():
243
+ new_cases = []
244
+
245
+ for case in CASES:
246
+ if not isinstance(case.a, np.ndarray):
247
+ continue
248
+
249
+ a = np.array([case.a, 2 * case.a, 3 * case.a])
250
+ if case.b is None:
251
+ b = None
252
+ else:
253
+ b = np.array([case.b, 7 * case.b, 6 * case.b])
254
+ new_case = LinalgCase(case.name + "_tile3", a, b,
255
+ tags=case.tags | {'generalized'})
256
+ new_cases.append(new_case)
257
+
258
+ a = np.array([case.a] * 2 * 3).reshape((3, 2) + case.a.shape)
259
+ if case.b is None:
260
+ b = None
261
+ else:
262
+ b = np.array([case.b] * 2 * 3).reshape((3, 2) + case.b.shape)
263
+ new_case = LinalgCase(case.name + "_tile213", a, b,
264
+ tags=case.tags | {'generalized'})
265
+ new_cases.append(new_case)
266
+
267
+ return new_cases
268
+
269
+
270
+ CASES += _make_generalized_cases()
271
+
272
+
273
+ #
274
+ # Generate stride combination variations of the above
275
+ #
276
+ def _stride_comb_iter(x):
277
+ """
278
+ Generate cartesian product of strides for all axes
279
+ """
280
+
281
+ if not isinstance(x, np.ndarray):
282
+ yield x, "nop"
283
+ return
284
+
285
+ stride_set = [(1,)] * x.ndim
286
+ stride_set[-1] = (1, 3, -4)
287
+ if x.ndim > 1:
288
+ stride_set[-2] = (1, 3, -4)
289
+ if x.ndim > 2:
290
+ stride_set[-3] = (1, -4)
291
+
292
+ for repeats in itertools.product(*tuple(stride_set)):
293
+ new_shape = [abs(a * b) for a, b in zip(x.shape, repeats)]
294
+ slices = tuple([slice(None, None, repeat) for repeat in repeats])
295
+
296
+ # new array with different strides, but same data
297
+ xi = np.empty(new_shape, dtype=x.dtype)
298
+ xi.view(np.uint32).fill(0xdeadbeef)
299
+ xi = xi[slices]
300
+ xi[...] = x
301
+ xi = xi.view(x.__class__)
302
+ assert_(np.all(xi == x))
303
+ yield xi, "stride_" + "_".join(["%+d" % j for j in repeats])
304
+
305
+ # generate also zero strides if possible
306
+ if x.ndim >= 1 and x.shape[-1] == 1:
307
+ s = list(x.strides)
308
+ s[-1] = 0
309
+ xi = np.lib.stride_tricks.as_strided(x, strides=s)
310
+ yield xi, "stride_xxx_0"
311
+ if x.ndim >= 2 and x.shape[-2] == 1:
312
+ s = list(x.strides)
313
+ s[-2] = 0
314
+ xi = np.lib.stride_tricks.as_strided(x, strides=s)
315
+ yield xi, "stride_xxx_0_x"
316
+ if x.ndim >= 2 and x.shape[:-2] == (1, 1):
317
+ s = list(x.strides)
318
+ s[-1] = 0
319
+ s[-2] = 0
320
+ xi = np.lib.stride_tricks.as_strided(x, strides=s)
321
+ yield xi, "stride_xxx_0_0"
322
+
323
+
324
+ def _make_strided_cases():
325
+ new_cases = []
326
+ for case in CASES:
327
+ for a, a_label in _stride_comb_iter(case.a):
328
+ for b, b_label in _stride_comb_iter(case.b):
329
+ new_case = LinalgCase(case.name + "_" + a_label + "_" + b_label, a, b,
330
+ tags=case.tags | {'strided'})
331
+ new_cases.append(new_case)
332
+ return new_cases
333
+
334
+
335
+ CASES += _make_strided_cases()
336
+
337
+
338
+ #
339
+ # Test different routines against the above cases
340
+ #
341
+ class LinalgTestCase:
342
+ TEST_CASES = CASES
343
+
344
+ def check_cases(self, require=set(), exclude=set()):
345
+ """
346
+ Run func on each of the cases with all of the tags in require, and none
347
+ of the tags in exclude
348
+ """
349
+ for case in self.TEST_CASES:
350
+ # filter by require and exclude
351
+ if case.tags & require != require:
352
+ continue
353
+ if case.tags & exclude:
354
+ continue
355
+
356
+ try:
357
+ case.check(self.do)
358
+ except Exception as e:
359
+ msg = f'In test case: {case!r}\n\n'
360
+ msg += traceback.format_exc()
361
+ raise AssertionError(msg) from e
362
+
363
+
364
+ class LinalgSquareTestCase(LinalgTestCase):
365
+
366
+ def test_sq_cases(self):
367
+ self.check_cases(require={'square'},
368
+ exclude={'generalized', 'size-0'})
369
+
370
+ def test_empty_sq_cases(self):
371
+ self.check_cases(require={'square', 'size-0'},
372
+ exclude={'generalized'})
373
+
374
+
375
+ class LinalgNonsquareTestCase(LinalgTestCase):
376
+
377
+ def test_nonsq_cases(self):
378
+ self.check_cases(require={'nonsquare'},
379
+ exclude={'generalized', 'size-0'})
380
+
381
+ def test_empty_nonsq_cases(self):
382
+ self.check_cases(require={'nonsquare', 'size-0'},
383
+ exclude={'generalized'})
384
+
385
+
386
+ class HermitianTestCase(LinalgTestCase):
387
+
388
+ def test_herm_cases(self):
389
+ self.check_cases(require={'hermitian'},
390
+ exclude={'generalized', 'size-0'})
391
+
392
+ def test_empty_herm_cases(self):
393
+ self.check_cases(require={'hermitian', 'size-0'},
394
+ exclude={'generalized'})
395
+
396
+
397
+ class LinalgGeneralizedSquareTestCase(LinalgTestCase):
398
+
399
+ @pytest.mark.slow
400
+ def test_generalized_sq_cases(self):
401
+ self.check_cases(require={'generalized', 'square'},
402
+ exclude={'size-0'})
403
+
404
+ @pytest.mark.slow
405
+ def test_generalized_empty_sq_cases(self):
406
+ self.check_cases(require={'generalized', 'square', 'size-0'})
407
+
408
+
409
+ class LinalgGeneralizedNonsquareTestCase(LinalgTestCase):
410
+
411
+ @pytest.mark.slow
412
+ def test_generalized_nonsq_cases(self):
413
+ self.check_cases(require={'generalized', 'nonsquare'},
414
+ exclude={'size-0'})
415
+
416
+ @pytest.mark.slow
417
+ def test_generalized_empty_nonsq_cases(self):
418
+ self.check_cases(require={'generalized', 'nonsquare', 'size-0'})
419
+
420
+
421
+ class HermitianGeneralizedTestCase(LinalgTestCase):
422
+
423
+ @pytest.mark.slow
424
+ def test_generalized_herm_cases(self):
425
+ self.check_cases(require={'generalized', 'hermitian'},
426
+ exclude={'size-0'})
427
+
428
+ @pytest.mark.slow
429
+ def test_generalized_empty_herm_cases(self):
430
+ self.check_cases(require={'generalized', 'hermitian', 'size-0'},
431
+ exclude={'none'})
432
+
433
+
434
+ def dot_generalized(a, b):
435
+ a = asarray(a)
436
+ if a.ndim >= 3:
437
+ if a.ndim == b.ndim:
438
+ # matrix x matrix
439
+ new_shape = a.shape[:-1] + b.shape[-1:]
440
+ elif a.ndim == b.ndim + 1:
441
+ # matrix x vector
442
+ new_shape = a.shape[:-1]
443
+ else:
444
+ raise ValueError("Not implemented...")
445
+ r = np.empty(new_shape, dtype=np.common_type(a, b))
446
+ for c in itertools.product(*map(range, a.shape[:-2])):
447
+ r[c] = dot(a[c], b[c])
448
+ return r
449
+ else:
450
+ return dot(a, b)
451
+
452
+
453
+ def identity_like_generalized(a):
454
+ a = asarray(a)
455
+ if a.ndim >= 3:
456
+ r = np.empty(a.shape, dtype=a.dtype)
457
+ r[...] = identity(a.shape[-2])
458
+ return r
459
+ else:
460
+ return identity(a.shape[0])
461
+
462
+
463
+ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
464
+ # kept apart from TestSolve for use for testing with matrices.
465
+ def do(self, a, b, tags):
466
+ x = linalg.solve(a, b)
467
+ assert_almost_equal(b, dot_generalized(a, x))
468
+ assert_(consistent_subclass(x, b))
469
+
470
+
471
+ class TestSolve(SolveCases):
472
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
473
+ def test_types(self, dtype):
474
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
475
+ assert_equal(linalg.solve(x, x).dtype, dtype)
476
+
477
+ def test_0_size(self):
478
+ class ArraySubclass(np.ndarray):
479
+ pass
480
+ # Test system of 0x0 matrices
481
+ a = np.arange(8).reshape(2, 2, 2)
482
+ b = np.arange(6).reshape(1, 2, 3).view(ArraySubclass)
483
+
484
+ expected = linalg.solve(a, b)[:, 0:0, :]
485
+ result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0, :])
486
+ assert_array_equal(result, expected)
487
+ assert_(isinstance(result, ArraySubclass))
488
+
489
+ # Test errors for non-square and only b's dimension being 0
490
+ assert_raises(linalg.LinAlgError, linalg.solve, a[:, 0:0, 0:1], b)
491
+ assert_raises(ValueError, linalg.solve, a, b[:, 0:0, :])
492
+
493
+ # Test broadcasting error
494
+ b = np.arange(6).reshape(1, 3, 2) # broadcasting error
495
+ assert_raises(ValueError, linalg.solve, a, b)
496
+ assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
497
+
498
+ # Test zero "single equations" with 0x0 matrices.
499
+ b = np.arange(2).reshape(1, 2).view(ArraySubclass)
500
+ expected = linalg.solve(a, b)[:, 0:0]
501
+ result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0])
502
+ assert_array_equal(result, expected)
503
+ assert_(isinstance(result, ArraySubclass))
504
+
505
+ b = np.arange(3).reshape(1, 3)
506
+ assert_raises(ValueError, linalg.solve, a, b)
507
+ assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
508
+ assert_raises(ValueError, linalg.solve, a[:, 0:0, 0:0], b)
509
+
510
+ def test_0_size_k(self):
511
+ # test zero multiple equation (K=0) case.
512
+ class ArraySubclass(np.ndarray):
513
+ pass
514
+ a = np.arange(4).reshape(1, 2, 2)
515
+ b = np.arange(6).reshape(3, 2, 1).view(ArraySubclass)
516
+
517
+ expected = linalg.solve(a, b)[:, :, 0:0]
518
+ result = linalg.solve(a, b[:, :, 0:0])
519
+ assert_array_equal(result, expected)
520
+ assert_(isinstance(result, ArraySubclass))
521
+
522
+ # test both zero.
523
+ expected = linalg.solve(a, b)[:, 0:0, 0:0]
524
+ result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0, 0:0])
525
+ assert_array_equal(result, expected)
526
+ assert_(isinstance(result, ArraySubclass))
527
+
528
+
529
+ class InvCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
530
+
531
+ def do(self, a, b, tags):
532
+ a_inv = linalg.inv(a)
533
+ assert_almost_equal(dot_generalized(a, a_inv),
534
+ identity_like_generalized(a))
535
+ assert_(consistent_subclass(a_inv, a))
536
+
537
+
538
+ class TestInv(InvCases):
539
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
540
+ def test_types(self, dtype):
541
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
542
+ assert_equal(linalg.inv(x).dtype, dtype)
543
+
544
+ def test_0_size(self):
545
+ # Check that all kinds of 0-sized arrays work
546
+ class ArraySubclass(np.ndarray):
547
+ pass
548
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
549
+ res = linalg.inv(a)
550
+ assert_(res.dtype.type is np.float64)
551
+ assert_equal(a.shape, res.shape)
552
+ assert_(isinstance(res, ArraySubclass))
553
+
554
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
555
+ res = linalg.inv(a)
556
+ assert_(res.dtype.type is np.complex64)
557
+ assert_equal(a.shape, res.shape)
558
+ assert_(isinstance(res, ArraySubclass))
559
+
560
+
561
+ class EigvalsCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
562
+
563
+ def do(self, a, b, tags):
564
+ ev = linalg.eigvals(a)
565
+ evalues, evectors = linalg.eig(a)
566
+ assert_almost_equal(ev, evalues)
567
+
568
+
569
+ class TestEigvals(EigvalsCases):
570
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
571
+ def test_types(self, dtype):
572
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
573
+ assert_equal(linalg.eigvals(x).dtype, dtype)
574
+ x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
575
+ assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))
576
+
577
+ def test_0_size(self):
578
+ # Check that all kinds of 0-sized arrays work
579
+ class ArraySubclass(np.ndarray):
580
+ pass
581
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
582
+ res = linalg.eigvals(a)
583
+ assert_(res.dtype.type is np.float64)
584
+ assert_equal((0, 1), res.shape)
585
+ # This is just for documentation, it might make sense to change:
586
+ assert_(isinstance(res, np.ndarray))
587
+
588
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
589
+ res = linalg.eigvals(a)
590
+ assert_(res.dtype.type is np.complex64)
591
+ assert_equal((0,), res.shape)
592
+ # This is just for documentation, it might make sense to change:
593
+ assert_(isinstance(res, np.ndarray))
594
+
595
+
596
+ class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
597
+
598
+ def do(self, a, b, tags):
599
+ res = linalg.eig(a)
600
+ eigenvalues, eigenvectors = res.eigenvalues, res.eigenvectors
601
+ assert_allclose(dot_generalized(a, eigenvectors),
602
+ np.asarray(eigenvectors) * np.asarray(eigenvalues)[..., None, :],
603
+ rtol=get_rtol(eigenvalues.dtype))
604
+ assert_(consistent_subclass(eigenvectors, a))
605
+
606
+
607
+ class TestEig(EigCases):
608
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
609
+ def test_types(self, dtype):
610
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
611
+ w, v = np.linalg.eig(x)
612
+ assert_equal(w.dtype, dtype)
613
+ assert_equal(v.dtype, dtype)
614
+
615
+ x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
616
+ w, v = np.linalg.eig(x)
617
+ assert_equal(w.dtype, get_complex_dtype(dtype))
618
+ assert_equal(v.dtype, get_complex_dtype(dtype))
619
+
620
+ def test_0_size(self):
621
+ # Check that all kinds of 0-sized arrays work
622
+ class ArraySubclass(np.ndarray):
623
+ pass
624
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
625
+ res, res_v = linalg.eig(a)
626
+ assert_(res_v.dtype.type is np.float64)
627
+ assert_(res.dtype.type is np.float64)
628
+ assert_equal(a.shape, res_v.shape)
629
+ assert_equal((0, 1), res.shape)
630
+ # This is just for documentation, it might make sense to change:
631
+ assert_(isinstance(a, np.ndarray))
632
+
633
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
634
+ res, res_v = linalg.eig(a)
635
+ assert_(res_v.dtype.type is np.complex64)
636
+ assert_(res.dtype.type is np.complex64)
637
+ assert_equal(a.shape, res_v.shape)
638
+ assert_equal((0,), res.shape)
639
+ # This is just for documentation, it might make sense to change:
640
+ assert_(isinstance(a, np.ndarray))
641
+
642
+
643
+ class SVDBaseTests:
644
+ hermitian = False
645
+
646
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
647
+ def test_types(self, dtype):
648
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
649
+ res = linalg.svd(x)
650
+ U, S, Vh = res.U, res.S, res.Vh
651
+ assert_equal(U.dtype, dtype)
652
+ assert_equal(S.dtype, get_real_dtype(dtype))
653
+ assert_equal(Vh.dtype, dtype)
654
+ s = linalg.svd(x, compute_uv=False, hermitian=self.hermitian)
655
+ assert_equal(s.dtype, get_real_dtype(dtype))
656
+
657
+
658
+ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
659
+
660
+ def do(self, a, b, tags):
661
+ u, s, vt = linalg.svd(a, False)
662
+ assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :],
663
+ np.asarray(vt)),
664
+ rtol=get_rtol(u.dtype))
665
+ assert_(consistent_subclass(u, a))
666
+ assert_(consistent_subclass(vt, a))
667
+
668
+
669
+ class TestSVD(SVDCases, SVDBaseTests):
670
+ def test_empty_identity(self):
671
+ """ Empty input should put an identity matrix in u or vh """
672
+ x = np.empty((4, 0))
673
+ u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
674
+ assert_equal(u.shape, (4, 4))
675
+ assert_equal(vh.shape, (0, 0))
676
+ assert_equal(u, np.eye(4))
677
+
678
+ x = np.empty((0, 4))
679
+ u, s, vh = linalg.svd(x, compute_uv=True, hermitian=self.hermitian)
680
+ assert_equal(u.shape, (0, 0))
681
+ assert_equal(vh.shape, (4, 4))
682
+ assert_equal(vh, np.eye(4))
683
+
684
+
685
+ class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
686
+
687
+ def do(self, a, b, tags):
688
+ u, s, vt = linalg.svd(a, False, hermitian=True)
689
+ assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :],
690
+ np.asarray(vt)),
691
+ rtol=get_rtol(u.dtype))
692
+ def hermitian(mat):
693
+ axes = list(range(mat.ndim))
694
+ axes[-1], axes[-2] = axes[-2], axes[-1]
695
+ return np.conj(np.transpose(mat, axes=axes))
696
+
697
+ assert_almost_equal(np.matmul(u, hermitian(u)), np.broadcast_to(np.eye(u.shape[-1]), u.shape))
698
+ assert_almost_equal(np.matmul(vt, hermitian(vt)), np.broadcast_to(np.eye(vt.shape[-1]), vt.shape))
699
+ assert_equal(np.sort(s)[..., ::-1], s)
700
+ assert_(consistent_subclass(u, a))
701
+ assert_(consistent_subclass(vt, a))
702
+
703
+
704
+ class TestSVDHermitian(SVDHermitianCases, SVDBaseTests):
705
+ hermitian = True
706
+
707
+
708
+ class CondCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
709
+ # cond(x, p) for p in (None, 2, -2)
710
+
711
+ def do(self, a, b, tags):
712
+ c = asarray(a) # a might be a matrix
713
+ if 'size-0' in tags:
714
+ assert_raises(LinAlgError, linalg.cond, c)
715
+ return
716
+
717
+ # +-2 norms
718
+ s = linalg.svd(c, compute_uv=False)
719
+ assert_almost_equal(
720
+ linalg.cond(a), s[..., 0] / s[..., -1],
721
+ single_decimal=5, double_decimal=11)
722
+ assert_almost_equal(
723
+ linalg.cond(a, 2), s[..., 0] / s[..., -1],
724
+ single_decimal=5, double_decimal=11)
725
+ assert_almost_equal(
726
+ linalg.cond(a, -2), s[..., -1] / s[..., 0],
727
+ single_decimal=5, double_decimal=11)
728
+
729
+ # Other norms
730
+ cinv = np.linalg.inv(c)
731
+ assert_almost_equal(
732
+ linalg.cond(a, 1),
733
+ abs(c).sum(-2).max(-1) * abs(cinv).sum(-2).max(-1),
734
+ single_decimal=5, double_decimal=11)
735
+ assert_almost_equal(
736
+ linalg.cond(a, -1),
737
+ abs(c).sum(-2).min(-1) * abs(cinv).sum(-2).min(-1),
738
+ single_decimal=5, double_decimal=11)
739
+ assert_almost_equal(
740
+ linalg.cond(a, np.inf),
741
+ abs(c).sum(-1).max(-1) * abs(cinv).sum(-1).max(-1),
742
+ single_decimal=5, double_decimal=11)
743
+ assert_almost_equal(
744
+ linalg.cond(a, -np.inf),
745
+ abs(c).sum(-1).min(-1) * abs(cinv).sum(-1).min(-1),
746
+ single_decimal=5, double_decimal=11)
747
+ assert_almost_equal(
748
+ linalg.cond(a, 'fro'),
749
+ np.sqrt((abs(c)**2).sum(-1).sum(-1)
750
+ * (abs(cinv)**2).sum(-1).sum(-1)),
751
+ single_decimal=5, double_decimal=11)
752
+
753
+
754
+ class TestCond(CondCases):
755
+ def test_basic_nonsvd(self):
756
+ # Smoketest the non-svd norms
757
+ A = array([[1., 0, 1], [0, -2., 0], [0, 0, 3.]])
758
+ assert_almost_equal(linalg.cond(A, inf), 4)
759
+ assert_almost_equal(linalg.cond(A, -inf), 2/3)
760
+ assert_almost_equal(linalg.cond(A, 1), 4)
761
+ assert_almost_equal(linalg.cond(A, -1), 0.5)
762
+ assert_almost_equal(linalg.cond(A, 'fro'), np.sqrt(265 / 12))
763
+
764
+ def test_singular(self):
765
+ # Singular matrices have infinite condition number for
766
+ # positive norms, and negative norms shouldn't raise
767
+ # exceptions
768
+ As = [np.zeros((2, 2)), np.ones((2, 2))]
769
+ p_pos = [None, 1, 2, 'fro']
770
+ p_neg = [-1, -2]
771
+ for A, p in itertools.product(As, p_pos):
772
+ # Inversion may not hit exact infinity, so just check the
773
+ # number is large
774
+ assert_(linalg.cond(A, p) > 1e15)
775
+ for A, p in itertools.product(As, p_neg):
776
+ linalg.cond(A, p)
777
+
778
+ @pytest.mark.xfail(True, run=False,
779
+ reason="Platform/LAPACK-dependent failure, "
780
+ "see gh-18914")
781
+ def test_nan(self):
782
+ # nans should be passed through, not converted to infs
783
+ ps = [None, 1, -1, 2, -2, 'fro']
784
+ p_pos = [None, 1, 2, 'fro']
785
+
786
+ A = np.ones((2, 2))
787
+ A[0,1] = np.nan
788
+ for p in ps:
789
+ c = linalg.cond(A, p)
790
+ assert_(isinstance(c, np.float_))
791
+ assert_(np.isnan(c))
792
+
793
+ A = np.ones((3, 2, 2))
794
+ A[1,0,1] = np.nan
795
+ for p in ps:
796
+ c = linalg.cond(A, p)
797
+ assert_(np.isnan(c[1]))
798
+ if p in p_pos:
799
+ assert_(c[0] > 1e15)
800
+ assert_(c[2] > 1e15)
801
+ else:
802
+ assert_(not np.isnan(c[0]))
803
+ assert_(not np.isnan(c[2]))
804
+
805
+ def test_stacked_singular(self):
806
+ # Check behavior when only some of the stacked matrices are
807
+ # singular
808
+ np.random.seed(1234)
809
+ A = np.random.rand(2, 2, 2, 2)
810
+ A[0,0] = 0
811
+ A[1,1] = 0
812
+
813
+ for p in (None, 1, 2, 'fro', -1, -2):
814
+ c = linalg.cond(A, p)
815
+ assert_equal(c[0,0], np.inf)
816
+ assert_equal(c[1,1], np.inf)
817
+ assert_(np.isfinite(c[0,1]))
818
+ assert_(np.isfinite(c[1,0]))
819
+
820
+
821
+ class PinvCases(LinalgSquareTestCase,
822
+ LinalgNonsquareTestCase,
823
+ LinalgGeneralizedSquareTestCase,
824
+ LinalgGeneralizedNonsquareTestCase):
825
+
826
+ def do(self, a, b, tags):
827
+ a_ginv = linalg.pinv(a)
828
+ # `a @ a_ginv == I` does not hold if a is singular
829
+ dot = dot_generalized
830
+ assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11)
831
+ assert_(consistent_subclass(a_ginv, a))
832
+
833
+
834
+ class TestPinv(PinvCases):
835
+ pass
836
+
837
+
838
+ class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
839
+
840
+ def do(self, a, b, tags):
841
+ a_ginv = linalg.pinv(a, hermitian=True)
842
+ # `a @ a_ginv == I` does not hold if a is singular
843
+ dot = dot_generalized
844
+ assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11)
845
+ assert_(consistent_subclass(a_ginv, a))
846
+
847
+
848
+ class TestPinvHermitian(PinvHermitianCases):
849
+ pass
850
+
851
+
852
+ class DetCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
853
+
854
+ def do(self, a, b, tags):
855
+ d = linalg.det(a)
856
+ res = linalg.slogdet(a)
857
+ s, ld = res.sign, res.logabsdet
858
+ if asarray(a).dtype.type in (single, double):
859
+ ad = asarray(a).astype(double)
860
+ else:
861
+ ad = asarray(a).astype(cdouble)
862
+ ev = linalg.eigvals(ad)
863
+ assert_almost_equal(d, multiply.reduce(ev, axis=-1))
864
+ assert_almost_equal(s * np.exp(ld), multiply.reduce(ev, axis=-1))
865
+
866
+ s = np.atleast_1d(s)
867
+ ld = np.atleast_1d(ld)
868
+ m = (s != 0)
869
+ assert_almost_equal(np.abs(s[m]), 1)
870
+ assert_equal(ld[~m], -inf)
871
+
872
+
873
+ class TestDet(DetCases):
874
+ def test_zero(self):
875
+ assert_equal(linalg.det([[0.0]]), 0.0)
876
+ assert_equal(type(linalg.det([[0.0]])), double)
877
+ assert_equal(linalg.det([[0.0j]]), 0.0)
878
+ assert_equal(type(linalg.det([[0.0j]])), cdouble)
879
+
880
+ assert_equal(linalg.slogdet([[0.0]]), (0.0, -inf))
881
+ assert_equal(type(linalg.slogdet([[0.0]])[0]), double)
882
+ assert_equal(type(linalg.slogdet([[0.0]])[1]), double)
883
+ assert_equal(linalg.slogdet([[0.0j]]), (0.0j, -inf))
884
+ assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble)
885
+ assert_equal(type(linalg.slogdet([[0.0j]])[1]), double)
886
+
887
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
888
+ def test_types(self, dtype):
889
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
890
+ assert_equal(np.linalg.det(x).dtype, dtype)
891
+ ph, s = np.linalg.slogdet(x)
892
+ assert_equal(s.dtype, get_real_dtype(dtype))
893
+ assert_equal(ph.dtype, dtype)
894
+
895
+ def test_0_size(self):
896
+ a = np.zeros((0, 0), dtype=np.complex64)
897
+ res = linalg.det(a)
898
+ assert_equal(res, 1.)
899
+ assert_(res.dtype.type is np.complex64)
900
+ res = linalg.slogdet(a)
901
+ assert_equal(res, (1, 0))
902
+ assert_(res[0].dtype.type is np.complex64)
903
+ assert_(res[1].dtype.type is np.float32)
904
+
905
+ a = np.zeros((0, 0), dtype=np.float64)
906
+ res = linalg.det(a)
907
+ assert_equal(res, 1.)
908
+ assert_(res.dtype.type is np.float64)
909
+ res = linalg.slogdet(a)
910
+ assert_equal(res, (1, 0))
911
+ assert_(res[0].dtype.type is np.float64)
912
+ assert_(res[1].dtype.type is np.float64)
913
+
914
+
915
+ class LstsqCases(LinalgSquareTestCase, LinalgNonsquareTestCase):
916
+
917
+ def do(self, a, b, tags):
918
+ arr = np.asarray(a)
919
+ m, n = arr.shape
920
+ u, s, vt = linalg.svd(a, False)
921
+ x, residuals, rank, sv = linalg.lstsq(a, b, rcond=-1)
922
+ if m == 0:
923
+ assert_((x == 0).all())
924
+ if m <= n:
925
+ assert_almost_equal(b, dot(a, x))
926
+ assert_equal(rank, m)
927
+ else:
928
+ assert_equal(rank, n)
929
+ assert_almost_equal(sv, sv.__array_wrap__(s))
930
+ if rank == n and m > n:
931
+ expect_resids = (
932
+ np.asarray(abs(np.dot(a, x) - b)) ** 2).sum(axis=0)
933
+ expect_resids = np.asarray(expect_resids)
934
+ if np.asarray(b).ndim == 1:
935
+ expect_resids.shape = (1,)
936
+ assert_equal(residuals.shape, expect_resids.shape)
937
+ else:
938
+ expect_resids = np.array([]).view(type(x))
939
+ assert_almost_equal(residuals, expect_resids)
940
+ assert_(np.issubdtype(residuals.dtype, np.floating))
941
+ assert_(consistent_subclass(x, b))
942
+ assert_(consistent_subclass(residuals, b))
943
+
944
+
945
+ class TestLstsq(LstsqCases):
946
+ def test_future_rcond(self):
947
+ a = np.array([[0., 1., 0., 1., 2., 0.],
948
+ [0., 2., 0., 0., 1., 0.],
949
+ [1., 0., 1., 0., 0., 4.],
950
+ [0., 0., 0., 2., 3., 0.]]).T
951
+
952
+ b = np.array([1, 0, 0, 0, 0, 0])
953
+ with suppress_warnings() as sup:
954
+ w = sup.record(FutureWarning, "`rcond` parameter will change")
955
+ x, residuals, rank, s = linalg.lstsq(a, b)
956
+ assert_(rank == 4)
957
+ x, residuals, rank, s = linalg.lstsq(a, b, rcond=-1)
958
+ assert_(rank == 4)
959
+ x, residuals, rank, s = linalg.lstsq(a, b, rcond=None)
960
+ assert_(rank == 3)
961
+ # Warning should be raised exactly once (first command)
962
+ assert_(len(w) == 1)
963
+
964
+ @pytest.mark.parametrize(["m", "n", "n_rhs"], [
965
+ (4, 2, 2),
966
+ (0, 4, 1),
967
+ (0, 4, 2),
968
+ (4, 0, 1),
969
+ (4, 0, 2),
970
+ (4, 2, 0),
971
+ (0, 0, 0)
972
+ ])
973
+ def test_empty_a_b(self, m, n, n_rhs):
974
+ a = np.arange(m * n).reshape(m, n)
975
+ b = np.ones((m, n_rhs))
976
+ x, residuals, rank, s = linalg.lstsq(a, b, rcond=None)
977
+ if m == 0:
978
+ assert_((x == 0).all())
979
+ assert_equal(x.shape, (n, n_rhs))
980
+ assert_equal(residuals.shape, ((n_rhs,) if m > n else (0,)))
981
+ if m > n and n_rhs > 0:
982
+ # residuals are exactly the squared norms of b's columns
983
+ r = b - np.dot(a, x)
984
+ assert_almost_equal(residuals, (r * r).sum(axis=-2))
985
+ assert_equal(rank, min(m, n))
986
+ assert_equal(s.shape, (min(m, n),))
987
+
988
+ def test_incompatible_dims(self):
989
+ # use modified version of docstring example
990
+ x = np.array([0, 1, 2, 3])
991
+ y = np.array([-1, 0.2, 0.9, 2.1, 3.3])
992
+ A = np.vstack([x, np.ones(len(x))]).T
993
+ with assert_raises_regex(LinAlgError, "Incompatible dimensions"):
994
+ linalg.lstsq(A, y, rcond=None)
995
+
996
+
997
+ @pytest.mark.parametrize('dt', [np.dtype(c) for c in '?bBhHiIqQefdgFDGO'])
998
+ class TestMatrixPower:
999
+
1000
+ rshft_0 = np.eye(4)
1001
+ rshft_1 = rshft_0[[3, 0, 1, 2]]
1002
+ rshft_2 = rshft_0[[2, 3, 0, 1]]
1003
+ rshft_3 = rshft_0[[1, 2, 3, 0]]
1004
+ rshft_all = [rshft_0, rshft_1, rshft_2, rshft_3]
1005
+ noninv = array([[1, 0], [0, 0]])
1006
+ stacked = np.block([[[rshft_0]]]*2)
1007
+ #FIXME the 'e' dtype might work in future
1008
+ dtnoinv = [object, np.dtype('e'), np.dtype('g'), np.dtype('G')]
1009
+
1010
+ def test_large_power(self, dt):
1011
+ rshft = self.rshft_1.astype(dt)
1012
+ assert_equal(
1013
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 0), self.rshft_0)
1014
+ assert_equal(
1015
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 1), self.rshft_1)
1016
+ assert_equal(
1017
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 2), self.rshft_2)
1018
+ assert_equal(
1019
+ matrix_power(rshft, 2**100 + 2**10 + 2**5 + 3), self.rshft_3)
1020
+
1021
+ def test_power_is_zero(self, dt):
1022
+ def tz(M):
1023
+ mz = matrix_power(M, 0)
1024
+ assert_equal(mz, identity_like_generalized(M))
1025
+ assert_equal(mz.dtype, M.dtype)
1026
+
1027
+ for mat in self.rshft_all:
1028
+ tz(mat.astype(dt))
1029
+ if dt != object:
1030
+ tz(self.stacked.astype(dt))
1031
+
1032
+ def test_power_is_one(self, dt):
1033
+ def tz(mat):
1034
+ mz = matrix_power(mat, 1)
1035
+ assert_equal(mz, mat)
1036
+ assert_equal(mz.dtype, mat.dtype)
1037
+
1038
+ for mat in self.rshft_all:
1039
+ tz(mat.astype(dt))
1040
+ if dt != object:
1041
+ tz(self.stacked.astype(dt))
1042
+
1043
+ def test_power_is_two(self, dt):
1044
+ def tz(mat):
1045
+ mz = matrix_power(mat, 2)
1046
+ mmul = matmul if mat.dtype != object else dot
1047
+ assert_equal(mz, mmul(mat, mat))
1048
+ assert_equal(mz.dtype, mat.dtype)
1049
+
1050
+ for mat in self.rshft_all:
1051
+ tz(mat.astype(dt))
1052
+ if dt != object:
1053
+ tz(self.stacked.astype(dt))
1054
+
1055
+ def test_power_is_minus_one(self, dt):
1056
+ def tz(mat):
1057
+ invmat = matrix_power(mat, -1)
1058
+ mmul = matmul if mat.dtype != object else dot
1059
+ assert_almost_equal(
1060
+ mmul(invmat, mat), identity_like_generalized(mat))
1061
+
1062
+ for mat in self.rshft_all:
1063
+ if dt not in self.dtnoinv:
1064
+ tz(mat.astype(dt))
1065
+
1066
+ def test_exceptions_bad_power(self, dt):
1067
+ mat = self.rshft_0.astype(dt)
1068
+ assert_raises(TypeError, matrix_power, mat, 1.5)
1069
+ assert_raises(TypeError, matrix_power, mat, [1])
1070
+
1071
+ def test_exceptions_non_square(self, dt):
1072
+ assert_raises(LinAlgError, matrix_power, np.array([1], dt), 1)
1073
+ assert_raises(LinAlgError, matrix_power, np.array([[1], [2]], dt), 1)
1074
+ assert_raises(LinAlgError, matrix_power, np.ones((4, 3, 2), dt), 1)
1075
+
1076
+ @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm")
1077
+ def test_exceptions_not_invertible(self, dt):
1078
+ if dt in self.dtnoinv:
1079
+ return
1080
+ mat = self.noninv.astype(dt)
1081
+ assert_raises(LinAlgError, matrix_power, mat, -1)
1082
+
1083
+
1084
+ class TestEigvalshCases(HermitianTestCase, HermitianGeneralizedTestCase):
1085
+
1086
+ def do(self, a, b, tags):
1087
+ # note that eigenvalue arrays returned by eig must be sorted since
1088
+ # their order isn't guaranteed.
1089
+ ev = linalg.eigvalsh(a, 'L')
1090
+ evalues, evectors = linalg.eig(a)
1091
+ evalues.sort(axis=-1)
1092
+ assert_allclose(ev, evalues, rtol=get_rtol(ev.dtype))
1093
+
1094
+ ev2 = linalg.eigvalsh(a, 'U')
1095
+ assert_allclose(ev2, evalues, rtol=get_rtol(ev.dtype))
1096
+
1097
+
1098
+ class TestEigvalsh:
1099
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
1100
+ def test_types(self, dtype):
1101
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
1102
+ w = np.linalg.eigvalsh(x)
1103
+ assert_equal(w.dtype, get_real_dtype(dtype))
1104
+
1105
+ def test_invalid(self):
1106
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
1107
+ assert_raises(ValueError, np.linalg.eigvalsh, x, UPLO="lrong")
1108
+ assert_raises(ValueError, np.linalg.eigvalsh, x, "lower")
1109
+ assert_raises(ValueError, np.linalg.eigvalsh, x, "upper")
1110
+
1111
+ def test_UPLO(self):
1112
+ Klo = np.array([[0, 0], [1, 0]], dtype=np.double)
1113
+ Kup = np.array([[0, 1], [0, 0]], dtype=np.double)
1114
+ tgt = np.array([-1, 1], dtype=np.double)
1115
+ rtol = get_rtol(np.double)
1116
+
1117
+ # Check default is 'L'
1118
+ w = np.linalg.eigvalsh(Klo)
1119
+ assert_allclose(w, tgt, rtol=rtol)
1120
+ # Check 'L'
1121
+ w = np.linalg.eigvalsh(Klo, UPLO='L')
1122
+ assert_allclose(w, tgt, rtol=rtol)
1123
+ # Check 'l'
1124
+ w = np.linalg.eigvalsh(Klo, UPLO='l')
1125
+ assert_allclose(w, tgt, rtol=rtol)
1126
+ # Check 'U'
1127
+ w = np.linalg.eigvalsh(Kup, UPLO='U')
1128
+ assert_allclose(w, tgt, rtol=rtol)
1129
+ # Check 'u'
1130
+ w = np.linalg.eigvalsh(Kup, UPLO='u')
1131
+ assert_allclose(w, tgt, rtol=rtol)
1132
+
1133
+ def test_0_size(self):
1134
+ # Check that all kinds of 0-sized arrays work
1135
+ class ArraySubclass(np.ndarray):
1136
+ pass
1137
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
1138
+ res = linalg.eigvalsh(a)
1139
+ assert_(res.dtype.type is np.float64)
1140
+ assert_equal((0, 1), res.shape)
1141
+ # This is just for documentation, it might make sense to change:
1142
+ assert_(isinstance(res, np.ndarray))
1143
+
1144
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
1145
+ res = linalg.eigvalsh(a)
1146
+ assert_(res.dtype.type is np.float32)
1147
+ assert_equal((0,), res.shape)
1148
+ # This is just for documentation, it might make sense to change:
1149
+ assert_(isinstance(res, np.ndarray))
1150
+
1151
+
1152
+ class TestEighCases(HermitianTestCase, HermitianGeneralizedTestCase):
1153
+
1154
+ def do(self, a, b, tags):
1155
+ # note that eigenvalue arrays returned by eig must be sorted since
1156
+ # their order isn't guaranteed.
1157
+ res = linalg.eigh(a)
1158
+ ev, evc = res.eigenvalues, res.eigenvectors
1159
+ evalues, evectors = linalg.eig(a)
1160
+ evalues.sort(axis=-1)
1161
+ assert_almost_equal(ev, evalues)
1162
+
1163
+ assert_allclose(dot_generalized(a, evc),
1164
+ np.asarray(ev)[..., None, :] * np.asarray(evc),
1165
+ rtol=get_rtol(ev.dtype))
1166
+
1167
+ ev2, evc2 = linalg.eigh(a, 'U')
1168
+ assert_almost_equal(ev2, evalues)
1169
+
1170
+ assert_allclose(dot_generalized(a, evc2),
1171
+ np.asarray(ev2)[..., None, :] * np.asarray(evc2),
1172
+ rtol=get_rtol(ev.dtype), err_msg=repr(a))
1173
+
1174
+
1175
+ class TestEigh:
1176
+ @pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
1177
+ def test_types(self, dtype):
1178
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
1179
+ w, v = np.linalg.eigh(x)
1180
+ assert_equal(w.dtype, get_real_dtype(dtype))
1181
+ assert_equal(v.dtype, dtype)
1182
+
1183
+ def test_invalid(self):
1184
+ x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
1185
+ assert_raises(ValueError, np.linalg.eigh, x, UPLO="lrong")
1186
+ assert_raises(ValueError, np.linalg.eigh, x, "lower")
1187
+ assert_raises(ValueError, np.linalg.eigh, x, "upper")
1188
+
1189
+ def test_UPLO(self):
1190
+ Klo = np.array([[0, 0], [1, 0]], dtype=np.double)
1191
+ Kup = np.array([[0, 1], [0, 0]], dtype=np.double)
1192
+ tgt = np.array([-1, 1], dtype=np.double)
1193
+ rtol = get_rtol(np.double)
1194
+
1195
+ # Check default is 'L'
1196
+ w, v = np.linalg.eigh(Klo)
1197
+ assert_allclose(w, tgt, rtol=rtol)
1198
+ # Check 'L'
1199
+ w, v = np.linalg.eigh(Klo, UPLO='L')
1200
+ assert_allclose(w, tgt, rtol=rtol)
1201
+ # Check 'l'
1202
+ w, v = np.linalg.eigh(Klo, UPLO='l')
1203
+ assert_allclose(w, tgt, rtol=rtol)
1204
+ # Check 'U'
1205
+ w, v = np.linalg.eigh(Kup, UPLO='U')
1206
+ assert_allclose(w, tgt, rtol=rtol)
1207
+ # Check 'u'
1208
+ w, v = np.linalg.eigh(Kup, UPLO='u')
1209
+ assert_allclose(w, tgt, rtol=rtol)
1210
+
1211
+ def test_0_size(self):
1212
+ # Check that all kinds of 0-sized arrays work
1213
+ class ArraySubclass(np.ndarray):
1214
+ pass
1215
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
1216
+ res, res_v = linalg.eigh(a)
1217
+ assert_(res_v.dtype.type is np.float64)
1218
+ assert_(res.dtype.type is np.float64)
1219
+ assert_equal(a.shape, res_v.shape)
1220
+ assert_equal((0, 1), res.shape)
1221
+ # This is just for documentation, it might make sense to change:
1222
+ assert_(isinstance(a, np.ndarray))
1223
+
1224
+ a = np.zeros((0, 0), dtype=np.complex64).view(ArraySubclass)
1225
+ res, res_v = linalg.eigh(a)
1226
+ assert_(res_v.dtype.type is np.complex64)
1227
+ assert_(res.dtype.type is np.float32)
1228
+ assert_equal(a.shape, res_v.shape)
1229
+ assert_equal((0,), res.shape)
1230
+ # This is just for documentation, it might make sense to change:
1231
+ assert_(isinstance(a, np.ndarray))
1232
+
1233
+
1234
+ class _TestNormBase:
1235
+ dt = None
1236
+ dec = None
1237
+
1238
+ @staticmethod
1239
+ def check_dtype(x, res):
1240
+ if issubclass(x.dtype.type, np.inexact):
1241
+ assert_equal(res.dtype, x.real.dtype)
1242
+ else:
1243
+ # For integer input, don't have to test float precision of output.
1244
+ assert_(issubclass(res.dtype.type, np.floating))
1245
+
1246
+
1247
+ class _TestNormGeneral(_TestNormBase):
1248
+
1249
+ def test_empty(self):
1250
+ assert_equal(norm([]), 0.0)
1251
+ assert_equal(norm(array([], dtype=self.dt)), 0.0)
1252
+ assert_equal(norm(atleast_2d(array([], dtype=self.dt))), 0.0)
1253
+
1254
+ def test_vector_return_type(self):
1255
+ a = np.array([1, 0, 1])
1256
+
1257
+ exact_types = np.typecodes['AllInteger']
1258
+ inexact_types = np.typecodes['AllFloat']
1259
+
1260
+ all_types = exact_types + inexact_types
1261
+
1262
+ for each_type in all_types:
1263
+ at = a.astype(each_type)
1264
+
1265
+ an = norm(at, -np.inf)
1266
+ self.check_dtype(at, an)
1267
+ assert_almost_equal(an, 0.0)
1268
+
1269
+ with suppress_warnings() as sup:
1270
+ sup.filter(RuntimeWarning, "divide by zero encountered")
1271
+ an = norm(at, -1)
1272
+ self.check_dtype(at, an)
1273
+ assert_almost_equal(an, 0.0)
1274
+
1275
+ an = norm(at, 0)
1276
+ self.check_dtype(at, an)
1277
+ assert_almost_equal(an, 2)
1278
+
1279
+ an = norm(at, 1)
1280
+ self.check_dtype(at, an)
1281
+ assert_almost_equal(an, 2.0)
1282
+
1283
+ an = norm(at, 2)
1284
+ self.check_dtype(at, an)
1285
+ assert_almost_equal(an, an.dtype.type(2.0)**an.dtype.type(1.0/2.0))
1286
+
1287
+ an = norm(at, 4)
1288
+ self.check_dtype(at, an)
1289
+ assert_almost_equal(an, an.dtype.type(2.0)**an.dtype.type(1.0/4.0))
1290
+
1291
+ an = norm(at, np.inf)
1292
+ self.check_dtype(at, an)
1293
+ assert_almost_equal(an, 1.0)
1294
+
1295
+ def test_vector(self):
1296
+ a = [1, 2, 3, 4]
1297
+ b = [-1, -2, -3, -4]
1298
+ c = [-1, 2, -3, 4]
1299
+
1300
+ def _test(v):
1301
+ np.testing.assert_almost_equal(norm(v), 30 ** 0.5,
1302
+ decimal=self.dec)
1303
+ np.testing.assert_almost_equal(norm(v, inf), 4.0,
1304
+ decimal=self.dec)
1305
+ np.testing.assert_almost_equal(norm(v, -inf), 1.0,
1306
+ decimal=self.dec)
1307
+ np.testing.assert_almost_equal(norm(v, 1), 10.0,
1308
+ decimal=self.dec)
1309
+ np.testing.assert_almost_equal(norm(v, -1), 12.0 / 25,
1310
+ decimal=self.dec)
1311
+ np.testing.assert_almost_equal(norm(v, 2), 30 ** 0.5,
1312
+ decimal=self.dec)
1313
+ np.testing.assert_almost_equal(norm(v, -2), ((205. / 144) ** -0.5),
1314
+ decimal=self.dec)
1315
+ np.testing.assert_almost_equal(norm(v, 0), 4,
1316
+ decimal=self.dec)
1317
+
1318
+ for v in (a, b, c,):
1319
+ _test(v)
1320
+
1321
+ for v in (array(a, dtype=self.dt), array(b, dtype=self.dt),
1322
+ array(c, dtype=self.dt)):
1323
+ _test(v)
1324
+
1325
+ def test_axis(self):
1326
+ # Vector norms.
1327
+ # Compare the use of `axis` with computing the norm of each row
1328
+ # or column separately.
1329
+ A = array([[1, 2, 3], [4, 5, 6]], dtype=self.dt)
1330
+ for order in [None, -1, 0, 1, 2, 3, np.Inf, -np.Inf]:
1331
+ expected0 = [norm(A[:, k], ord=order) for k in range(A.shape[1])]
1332
+ assert_almost_equal(norm(A, ord=order, axis=0), expected0)
1333
+ expected1 = [norm(A[k, :], ord=order) for k in range(A.shape[0])]
1334
+ assert_almost_equal(norm(A, ord=order, axis=1), expected1)
1335
+
1336
+ # Matrix norms.
1337
+ B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4)
1338
+ nd = B.ndim
1339
+ for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro']:
1340
+ for axis in itertools.combinations(range(-nd, nd), 2):
1341
+ row_axis, col_axis = axis
1342
+ if row_axis < 0:
1343
+ row_axis += nd
1344
+ if col_axis < 0:
1345
+ col_axis += nd
1346
+ if row_axis == col_axis:
1347
+ assert_raises(ValueError, norm, B, ord=order, axis=axis)
1348
+ else:
1349
+ n = norm(B, ord=order, axis=axis)
1350
+
1351
+ # The logic using k_index only works for nd = 3.
1352
+ # This has to be changed if nd is increased.
1353
+ k_index = nd - (row_axis + col_axis)
1354
+ if row_axis < col_axis:
1355
+ expected = [norm(B[:].take(k, axis=k_index), ord=order)
1356
+ for k in range(B.shape[k_index])]
1357
+ else:
1358
+ expected = [norm(B[:].take(k, axis=k_index).T, ord=order)
1359
+ for k in range(B.shape[k_index])]
1360
+ assert_almost_equal(n, expected)
1361
+
1362
+ def test_keepdims(self):
1363
+ A = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4)
1364
+
1365
+ allclose_err = 'order {0}, axis = {1}'
1366
+ shape_err = 'Shape mismatch found {0}, expected {1}, order={2}, axis={3}'
1367
+
1368
+ # check the order=None, axis=None case
1369
+ expected = norm(A, ord=None, axis=None)
1370
+ found = norm(A, ord=None, axis=None, keepdims=True)
1371
+ assert_allclose(np.squeeze(found), expected,
1372
+ err_msg=allclose_err.format(None, None))
1373
+ expected_shape = (1, 1, 1)
1374
+ assert_(found.shape == expected_shape,
1375
+ shape_err.format(found.shape, expected_shape, None, None))
1376
+
1377
+ # Vector norms.
1378
+ for order in [None, -1, 0, 1, 2, 3, np.Inf, -np.Inf]:
1379
+ for k in range(A.ndim):
1380
+ expected = norm(A, ord=order, axis=k)
1381
+ found = norm(A, ord=order, axis=k, keepdims=True)
1382
+ assert_allclose(np.squeeze(found), expected,
1383
+ err_msg=allclose_err.format(order, k))
1384
+ expected_shape = list(A.shape)
1385
+ expected_shape[k] = 1
1386
+ expected_shape = tuple(expected_shape)
1387
+ assert_(found.shape == expected_shape,
1388
+ shape_err.format(found.shape, expected_shape, order, k))
1389
+
1390
+ # Matrix norms.
1391
+ for order in [None, -2, 2, -1, 1, np.Inf, -np.Inf, 'fro', 'nuc']:
1392
+ for k in itertools.permutations(range(A.ndim), 2):
1393
+ expected = norm(A, ord=order, axis=k)
1394
+ found = norm(A, ord=order, axis=k, keepdims=True)
1395
+ assert_allclose(np.squeeze(found), expected,
1396
+ err_msg=allclose_err.format(order, k))
1397
+ expected_shape = list(A.shape)
1398
+ expected_shape[k[0]] = 1
1399
+ expected_shape[k[1]] = 1
1400
+ expected_shape = tuple(expected_shape)
1401
+ assert_(found.shape == expected_shape,
1402
+ shape_err.format(found.shape, expected_shape, order, k))
1403
+
1404
+
1405
+ class _TestNorm2D(_TestNormBase):
1406
+ # Define the part for 2d arrays separately, so we can subclass this
1407
+ # and run the tests using np.matrix in matrixlib.tests.test_matrix_linalg.
1408
+ array = np.array
1409
+
1410
+ def test_matrix_empty(self):
1411
+ assert_equal(norm(self.array([[]], dtype=self.dt)), 0.0)
1412
+
1413
+ def test_matrix_return_type(self):
1414
+ a = self.array([[1, 0, 1], [0, 1, 1]])
1415
+
1416
+ exact_types = np.typecodes['AllInteger']
1417
+
1418
+ # float32, complex64, float64, complex128 types are the only types
1419
+ # allowed by `linalg`, which performs the matrix operations used
1420
+ # within `norm`.
1421
+ inexact_types = 'fdFD'
1422
+
1423
+ all_types = exact_types + inexact_types
1424
+
1425
+ for each_type in all_types:
1426
+ at = a.astype(each_type)
1427
+
1428
+ an = norm(at, -np.inf)
1429
+ self.check_dtype(at, an)
1430
+ assert_almost_equal(an, 2.0)
1431
+
1432
+ with suppress_warnings() as sup:
1433
+ sup.filter(RuntimeWarning, "divide by zero encountered")
1434
+ an = norm(at, -1)
1435
+ self.check_dtype(at, an)
1436
+ assert_almost_equal(an, 1.0)
1437
+
1438
+ an = norm(at, 1)
1439
+ self.check_dtype(at, an)
1440
+ assert_almost_equal(an, 2.0)
1441
+
1442
+ an = norm(at, 2)
1443
+ self.check_dtype(at, an)
1444
+ assert_almost_equal(an, 3.0**(1.0/2.0))
1445
+
1446
+ an = norm(at, -2)
1447
+ self.check_dtype(at, an)
1448
+ assert_almost_equal(an, 1.0)
1449
+
1450
+ an = norm(at, np.inf)
1451
+ self.check_dtype(at, an)
1452
+ assert_almost_equal(an, 2.0)
1453
+
1454
+ an = norm(at, 'fro')
1455
+ self.check_dtype(at, an)
1456
+ assert_almost_equal(an, 2.0)
1457
+
1458
+ an = norm(at, 'nuc')
1459
+ self.check_dtype(at, an)
1460
+ # Lower bar needed to support low precision floats.
1461
+ # They end up being off by 1 in the 7th place.
1462
+ np.testing.assert_almost_equal(an, 2.7320508075688772, decimal=6)
1463
+
1464
+ def test_matrix_2x2(self):
1465
+ A = self.array([[1, 3], [5, 7]], dtype=self.dt)
1466
+ assert_almost_equal(norm(A), 84 ** 0.5)
1467
+ assert_almost_equal(norm(A, 'fro'), 84 ** 0.5)
1468
+ assert_almost_equal(norm(A, 'nuc'), 10.0)
1469
+ assert_almost_equal(norm(A, inf), 12.0)
1470
+ assert_almost_equal(norm(A, -inf), 4.0)
1471
+ assert_almost_equal(norm(A, 1), 10.0)
1472
+ assert_almost_equal(norm(A, -1), 6.0)
1473
+ assert_almost_equal(norm(A, 2), 9.1231056256176615)
1474
+ assert_almost_equal(norm(A, -2), 0.87689437438234041)
1475
+
1476
+ assert_raises(ValueError, norm, A, 'nofro')
1477
+ assert_raises(ValueError, norm, A, -3)
1478
+ assert_raises(ValueError, norm, A, 0)
1479
+
1480
+ def test_matrix_3x3(self):
1481
+ # This test has been added because the 2x2 example
1482
+ # happened to have equal nuclear norm and induced 1-norm.
1483
+ # The 1/10 scaling factor accommodates the absolute tolerance
1484
+ # used in assert_almost_equal.
1485
+ A = (1 / 10) * \
1486
+ self.array([[1, 2, 3], [6, 0, 5], [3, 2, 1]], dtype=self.dt)
1487
+ assert_almost_equal(norm(A), (1 / 10) * 89 ** 0.5)
1488
+ assert_almost_equal(norm(A, 'fro'), (1 / 10) * 89 ** 0.5)
1489
+ assert_almost_equal(norm(A, 'nuc'), 1.3366836911774836)
1490
+ assert_almost_equal(norm(A, inf), 1.1)
1491
+ assert_almost_equal(norm(A, -inf), 0.6)
1492
+ assert_almost_equal(norm(A, 1), 1.0)
1493
+ assert_almost_equal(norm(A, -1), 0.4)
1494
+ assert_almost_equal(norm(A, 2), 0.88722940323461277)
1495
+ assert_almost_equal(norm(A, -2), 0.19456584790481812)
1496
+
1497
+ def test_bad_args(self):
1498
+ # Check that bad arguments raise the appropriate exceptions.
1499
+
1500
+ A = self.array([[1, 2, 3], [4, 5, 6]], dtype=self.dt)
1501
+ B = np.arange(1, 25, dtype=self.dt).reshape(2, 3, 4)
1502
+
1503
+ # Using `axis=<integer>` or passing in a 1-D array implies vector
1504
+ # norms are being computed, so also using `ord='fro'`
1505
+ # or `ord='nuc'` or any other string raises a ValueError.
1506
+ assert_raises(ValueError, norm, A, 'fro', 0)
1507
+ assert_raises(ValueError, norm, A, 'nuc', 0)
1508
+ assert_raises(ValueError, norm, [3, 4], 'fro', None)
1509
+ assert_raises(ValueError, norm, [3, 4], 'nuc', None)
1510
+ assert_raises(ValueError, norm, [3, 4], 'test', None)
1511
+
1512
+ # Similarly, norm should raise an exception when ord is any finite
1513
+ # number other than 1, 2, -1 or -2 when computing matrix norms.
1514
+ for order in [0, 3]:
1515
+ assert_raises(ValueError, norm, A, order, None)
1516
+ assert_raises(ValueError, norm, A, order, (0, 1))
1517
+ assert_raises(ValueError, norm, B, order, (1, 2))
1518
+
1519
+ # Invalid axis
1520
+ assert_raises(np.AxisError, norm, B, None, 3)
1521
+ assert_raises(np.AxisError, norm, B, None, (2, 3))
1522
+ assert_raises(ValueError, norm, B, None, (0, 1, 2))
1523
+
1524
+
1525
+ class _TestNorm(_TestNorm2D, _TestNormGeneral):
1526
+ pass
1527
+
1528
+
1529
+ class TestNorm_NonSystematic:
1530
+
1531
+ def test_longdouble_norm(self):
1532
+ # Non-regression test: p-norm of longdouble would previously raise
1533
+ # UnboundLocalError.
1534
+ x = np.arange(10, dtype=np.longdouble)
1535
+ old_assert_almost_equal(norm(x, ord=3), 12.65, decimal=2)
1536
+
1537
+ def test_intmin(self):
1538
+ # Non-regression test: p-norm of signed integer would previously do
1539
+ # float cast and abs in the wrong order.
1540
+ x = np.array([-2 ** 31], dtype=np.int32)
1541
+ old_assert_almost_equal(norm(x, ord=3), 2 ** 31, decimal=5)
1542
+
1543
+ def test_complex_high_ord(self):
1544
+ # gh-4156
1545
+ d = np.empty((2,), dtype=np.clongdouble)
1546
+ d[0] = 6 + 7j
1547
+ d[1] = -6 + 7j
1548
+ res = 11.615898132184
1549
+ old_assert_almost_equal(np.linalg.norm(d, ord=3), res, decimal=10)
1550
+ d = d.astype(np.complex128)
1551
+ old_assert_almost_equal(np.linalg.norm(d, ord=3), res, decimal=9)
1552
+ d = d.astype(np.complex64)
1553
+ old_assert_almost_equal(np.linalg.norm(d, ord=3), res, decimal=5)
1554
+
1555
+
1556
+ # Separate definitions so we can use them for matrix tests.
1557
+ class _TestNormDoubleBase(_TestNormBase):
1558
+ dt = np.double
1559
+ dec = 12
1560
+
1561
+
1562
+ class _TestNormSingleBase(_TestNormBase):
1563
+ dt = np.float32
1564
+ dec = 6
1565
+
1566
+
1567
+ class _TestNormInt64Base(_TestNormBase):
1568
+ dt = np.int64
1569
+ dec = 12
1570
+
1571
+
1572
+ class TestNormDouble(_TestNorm, _TestNormDoubleBase):
1573
+ pass
1574
+
1575
+
1576
+ class TestNormSingle(_TestNorm, _TestNormSingleBase):
1577
+ pass
1578
+
1579
+
1580
+ class TestNormInt64(_TestNorm, _TestNormInt64Base):
1581
+ pass
1582
+
1583
+
1584
+ class TestMatrixRank:
1585
+
1586
+ def test_matrix_rank(self):
1587
+ # Full rank matrix
1588
+ assert_equal(4, matrix_rank(np.eye(4)))
1589
+ # rank deficient matrix
1590
+ I = np.eye(4)
1591
+ I[-1, -1] = 0.
1592
+ assert_equal(matrix_rank(I), 3)
1593
+ # All zeros - zero rank
1594
+ assert_equal(matrix_rank(np.zeros((4, 4))), 0)
1595
+ # 1 dimension - rank 1 unless all 0
1596
+ assert_equal(matrix_rank([1, 0, 0, 0]), 1)
1597
+ assert_equal(matrix_rank(np.zeros((4,))), 0)
1598
+ # accepts array-like
1599
+ assert_equal(matrix_rank([1]), 1)
1600
+ # greater than 2 dimensions treated as stacked matrices
1601
+ ms = np.array([I, np.eye(4), np.zeros((4,4))])
1602
+ assert_equal(matrix_rank(ms), np.array([3, 4, 0]))
1603
+ # works on scalar
1604
+ assert_equal(matrix_rank(1), 1)
1605
+
1606
+ def test_symmetric_rank(self):
1607
+ assert_equal(4, matrix_rank(np.eye(4), hermitian=True))
1608
+ assert_equal(1, matrix_rank(np.ones((4, 4)), hermitian=True))
1609
+ assert_equal(0, matrix_rank(np.zeros((4, 4)), hermitian=True))
1610
+ # rank deficient matrix
1611
+ I = np.eye(4)
1612
+ I[-1, -1] = 0.
1613
+ assert_equal(3, matrix_rank(I, hermitian=True))
1614
+ # manually supplied tolerance
1615
+ I[-1, -1] = 1e-8
1616
+ assert_equal(4, matrix_rank(I, hermitian=True, tol=0.99e-8))
1617
+ assert_equal(3, matrix_rank(I, hermitian=True, tol=1.01e-8))
1618
+
1619
+
1620
+ def test_reduced_rank():
1621
+ # Test matrices with reduced rank
1622
+ rng = np.random.RandomState(20120714)
1623
+ for i in range(100):
1624
+ # Make a rank deficient matrix
1625
+ X = rng.normal(size=(40, 10))
1626
+ X[:, 0] = X[:, 1] + X[:, 2]
1627
+ # Assert that matrix_rank detected deficiency
1628
+ assert_equal(matrix_rank(X), 9)
1629
+ X[:, 3] = X[:, 4] + X[:, 5]
1630
+ assert_equal(matrix_rank(X), 8)
1631
+
1632
+
1633
+ class TestQR:
1634
+ # Define the array class here, so run this on matrices elsewhere.
1635
+ array = np.array
1636
+
1637
+ def check_qr(self, a):
1638
+ # This test expects the argument `a` to be an ndarray or
1639
+ # a subclass of an ndarray of inexact type.
1640
+ a_type = type(a)
1641
+ a_dtype = a.dtype
1642
+ m, n = a.shape
1643
+ k = min(m, n)
1644
+
1645
+ # mode == 'complete'
1646
+ res = linalg.qr(a, mode='complete')
1647
+ Q, R = res.Q, res.R
1648
+ assert_(Q.dtype == a_dtype)
1649
+ assert_(R.dtype == a_dtype)
1650
+ assert_(isinstance(Q, a_type))
1651
+ assert_(isinstance(R, a_type))
1652
+ assert_(Q.shape == (m, m))
1653
+ assert_(R.shape == (m, n))
1654
+ assert_almost_equal(dot(Q, R), a)
1655
+ assert_almost_equal(dot(Q.T.conj(), Q), np.eye(m))
1656
+ assert_almost_equal(np.triu(R), R)
1657
+
1658
+ # mode == 'reduced'
1659
+ q1, r1 = linalg.qr(a, mode='reduced')
1660
+ assert_(q1.dtype == a_dtype)
1661
+ assert_(r1.dtype == a_dtype)
1662
+ assert_(isinstance(q1, a_type))
1663
+ assert_(isinstance(r1, a_type))
1664
+ assert_(q1.shape == (m, k))
1665
+ assert_(r1.shape == (k, n))
1666
+ assert_almost_equal(dot(q1, r1), a)
1667
+ assert_almost_equal(dot(q1.T.conj(), q1), np.eye(k))
1668
+ assert_almost_equal(np.triu(r1), r1)
1669
+
1670
+ # mode == 'r'
1671
+ r2 = linalg.qr(a, mode='r')
1672
+ assert_(r2.dtype == a_dtype)
1673
+ assert_(isinstance(r2, a_type))
1674
+ assert_almost_equal(r2, r1)
1675
+
1676
+
1677
+ @pytest.mark.parametrize(["m", "n"], [
1678
+ (3, 0),
1679
+ (0, 3),
1680
+ (0, 0)
1681
+ ])
1682
+ def test_qr_empty(self, m, n):
1683
+ k = min(m, n)
1684
+ a = np.empty((m, n))
1685
+
1686
+ self.check_qr(a)
1687
+
1688
+ h, tau = np.linalg.qr(a, mode='raw')
1689
+ assert_equal(h.dtype, np.double)
1690
+ assert_equal(tau.dtype, np.double)
1691
+ assert_equal(h.shape, (n, m))
1692
+ assert_equal(tau.shape, (k,))
1693
+
1694
+ def test_mode_raw(self):
1695
+ # The factorization is not unique and varies between libraries,
1696
+ # so it is not possible to check against known values. Functional
1697
+ # testing is a possibility, but awaits the exposure of more
1698
+ # of the functions in lapack_lite. Consequently, this test is
1699
+ # very limited in scope. Note that the results are in FORTRAN
1700
+ # order, hence the h arrays are transposed.
1701
+ a = self.array([[1, 2], [3, 4], [5, 6]], dtype=np.double)
1702
+
1703
+ # Test double
1704
+ h, tau = linalg.qr(a, mode='raw')
1705
+ assert_(h.dtype == np.double)
1706
+ assert_(tau.dtype == np.double)
1707
+ assert_(h.shape == (2, 3))
1708
+ assert_(tau.shape == (2,))
1709
+
1710
+ h, tau = linalg.qr(a.T, mode='raw')
1711
+ assert_(h.dtype == np.double)
1712
+ assert_(tau.dtype == np.double)
1713
+ assert_(h.shape == (3, 2))
1714
+ assert_(tau.shape == (2,))
1715
+
1716
+ def test_mode_all_but_economic(self):
1717
+ a = self.array([[1, 2], [3, 4]])
1718
+ b = self.array([[1, 2], [3, 4], [5, 6]])
1719
+ for dt in "fd":
1720
+ m1 = a.astype(dt)
1721
+ m2 = b.astype(dt)
1722
+ self.check_qr(m1)
1723
+ self.check_qr(m2)
1724
+ self.check_qr(m2.T)
1725
+
1726
+ for dt in "fd":
1727
+ m1 = 1 + 1j * a.astype(dt)
1728
+ m2 = 1 + 1j * b.astype(dt)
1729
+ self.check_qr(m1)
1730
+ self.check_qr(m2)
1731
+ self.check_qr(m2.T)
1732
+
1733
+ def check_qr_stacked(self, a):
1734
+ # This test expects the argument `a` to be an ndarray or
1735
+ # a subclass of an ndarray of inexact type.
1736
+ a_type = type(a)
1737
+ a_dtype = a.dtype
1738
+ m, n = a.shape[-2:]
1739
+ k = min(m, n)
1740
+
1741
+ # mode == 'complete'
1742
+ q, r = linalg.qr(a, mode='complete')
1743
+ assert_(q.dtype == a_dtype)
1744
+ assert_(r.dtype == a_dtype)
1745
+ assert_(isinstance(q, a_type))
1746
+ assert_(isinstance(r, a_type))
1747
+ assert_(q.shape[-2:] == (m, m))
1748
+ assert_(r.shape[-2:] == (m, n))
1749
+ assert_almost_equal(matmul(q, r), a)
1750
+ I_mat = np.identity(q.shape[-1])
1751
+ stack_I_mat = np.broadcast_to(I_mat,
1752
+ q.shape[:-2] + (q.shape[-1],)*2)
1753
+ assert_almost_equal(matmul(swapaxes(q, -1, -2).conj(), q), stack_I_mat)
1754
+ assert_almost_equal(np.triu(r[..., :, :]), r)
1755
+
1756
+ # mode == 'reduced'
1757
+ q1, r1 = linalg.qr(a, mode='reduced')
1758
+ assert_(q1.dtype == a_dtype)
1759
+ assert_(r1.dtype == a_dtype)
1760
+ assert_(isinstance(q1, a_type))
1761
+ assert_(isinstance(r1, a_type))
1762
+ assert_(q1.shape[-2:] == (m, k))
1763
+ assert_(r1.shape[-2:] == (k, n))
1764
+ assert_almost_equal(matmul(q1, r1), a)
1765
+ I_mat = np.identity(q1.shape[-1])
1766
+ stack_I_mat = np.broadcast_to(I_mat,
1767
+ q1.shape[:-2] + (q1.shape[-1],)*2)
1768
+ assert_almost_equal(matmul(swapaxes(q1, -1, -2).conj(), q1),
1769
+ stack_I_mat)
1770
+ assert_almost_equal(np.triu(r1[..., :, :]), r1)
1771
+
1772
+ # mode == 'r'
1773
+ r2 = linalg.qr(a, mode='r')
1774
+ assert_(r2.dtype == a_dtype)
1775
+ assert_(isinstance(r2, a_type))
1776
+ assert_almost_equal(r2, r1)
1777
+
1778
+ @pytest.mark.parametrize("size", [
1779
+ (3, 4), (4, 3), (4, 4),
1780
+ (3, 0), (0, 3)])
1781
+ @pytest.mark.parametrize("outer_size", [
1782
+ (2, 2), (2,), (2, 3, 4)])
1783
+ @pytest.mark.parametrize("dt", [
1784
+ np.single, np.double,
1785
+ np.csingle, np.cdouble])
1786
+ def test_stacked_inputs(self, outer_size, size, dt):
1787
+
1788
+ A = np.random.normal(size=outer_size + size).astype(dt)
1789
+ B = np.random.normal(size=outer_size + size).astype(dt)
1790
+ self.check_qr_stacked(A)
1791
+ self.check_qr_stacked(A + 1.j*B)
1792
+
1793
+
1794
+ class TestCholesky:
1795
+ # TODO: are there no other tests for cholesky?
1796
+
1797
+ @pytest.mark.parametrize(
1798
+ 'shape', [(1, 1), (2, 2), (3, 3), (50, 50), (3, 10, 10)]
1799
+ )
1800
+ @pytest.mark.parametrize(
1801
+ 'dtype', (np.float32, np.float64, np.complex64, np.complex128)
1802
+ )
1803
+ def test_basic_property(self, shape, dtype):
1804
+ # Check A = L L^H
1805
+ np.random.seed(1)
1806
+ a = np.random.randn(*shape)
1807
+ if np.issubdtype(dtype, np.complexfloating):
1808
+ a = a + 1j*np.random.randn(*shape)
1809
+
1810
+ t = list(range(len(shape)))
1811
+ t[-2:] = -1, -2
1812
+
1813
+ a = np.matmul(a.transpose(t).conj(), a)
1814
+ a = np.asarray(a, dtype=dtype)
1815
+
1816
+ c = np.linalg.cholesky(a)
1817
+
1818
+ b = np.matmul(c, c.transpose(t).conj())
1819
+ with np._no_nep50_warning():
1820
+ atol = 500 * a.shape[0] * np.finfo(dtype).eps
1821
+ assert_allclose(b, a, atol=atol, err_msg=f'{shape} {dtype}\n{a}\n{c}')
1822
+
1823
+ def test_0_size(self):
1824
+ class ArraySubclass(np.ndarray):
1825
+ pass
1826
+ a = np.zeros((0, 1, 1), dtype=np.int_).view(ArraySubclass)
1827
+ res = linalg.cholesky(a)
1828
+ assert_equal(a.shape, res.shape)
1829
+ assert_(res.dtype.type is np.float64)
1830
+ # for documentation purpose:
1831
+ assert_(isinstance(res, np.ndarray))
1832
+
1833
+ a = np.zeros((1, 0, 0), dtype=np.complex64).view(ArraySubclass)
1834
+ res = linalg.cholesky(a)
1835
+ assert_equal(a.shape, res.shape)
1836
+ assert_(res.dtype.type is np.complex64)
1837
+ assert_(isinstance(res, np.ndarray))
1838
+
1839
+
1840
+ def test_byteorder_check():
1841
+ # Byte order check should pass for native order
1842
+ if sys.byteorder == 'little':
1843
+ native = '<'
1844
+ else:
1845
+ native = '>'
1846
+
1847
+ for dtt in (np.float32, np.float64):
1848
+ arr = np.eye(4, dtype=dtt)
1849
+ n_arr = arr.newbyteorder(native)
1850
+ sw_arr = arr.newbyteorder('S').byteswap()
1851
+ assert_equal(arr.dtype.byteorder, '=')
1852
+ for routine in (linalg.inv, linalg.det, linalg.pinv):
1853
+ # Normal call
1854
+ res = routine(arr)
1855
+ # Native but not '='
1856
+ assert_array_equal(res, routine(n_arr))
1857
+ # Swapped
1858
+ assert_array_equal(res, routine(sw_arr))
1859
+
1860
+
1861
+ @pytest.mark.skipif(IS_WASM, reason="fp errors don't work in wasm")
1862
+ def test_generalized_raise_multiloop():
1863
+ # It should raise an error even if the error doesn't occur in the
1864
+ # last iteration of the ufunc inner loop
1865
+
1866
+ invertible = np.array([[1, 2], [3, 4]])
1867
+ non_invertible = np.array([[1, 1], [1, 1]])
1868
+
1869
+ x = np.zeros([4, 4, 2, 2])[1::2]
1870
+ x[...] = invertible
1871
+ x[0, 0] = non_invertible
1872
+
1873
+ assert_raises(np.linalg.LinAlgError, np.linalg.inv, x)
1874
+
1875
+
1876
+ def test_xerbla_override():
1877
+ # Check that our xerbla has been successfully linked in. If it is not,
1878
+ # the default xerbla routine is called, which prints a message to stdout
1879
+ # and may, or may not, abort the process depending on the LAPACK package.
1880
+
1881
+ XERBLA_OK = 255
1882
+
1883
+ try:
1884
+ pid = os.fork()
1885
+ except (OSError, AttributeError):
1886
+ # fork failed, or not running on POSIX
1887
+ pytest.skip("Not POSIX or fork failed.")
1888
+
1889
+ if pid == 0:
1890
+ # child; close i/o file handles
1891
+ os.close(1)
1892
+ os.close(0)
1893
+ # Avoid producing core files.
1894
+ import resource
1895
+ resource.setrlimit(resource.RLIMIT_CORE, (0, 0))
1896
+ # These calls may abort.
1897
+ try:
1898
+ np.linalg.lapack_lite.xerbla()
1899
+ except ValueError:
1900
+ pass
1901
+ except Exception:
1902
+ os._exit(os.EX_CONFIG)
1903
+
1904
+ try:
1905
+ a = np.array([[1.]])
1906
+ np.linalg.lapack_lite.dorgqr(
1907
+ 1, 1, 1, a,
1908
+ 0, # <- invalid value
1909
+ a, a, 0, 0)
1910
+ except ValueError as e:
1911
+ if "DORGQR parameter number 5" in str(e):
1912
+ # success, reuse error code to mark success as
1913
+ # FORTRAN STOP returns as success.
1914
+ os._exit(XERBLA_OK)
1915
+
1916
+ # Did not abort, but our xerbla was not linked in.
1917
+ os._exit(os.EX_CONFIG)
1918
+ else:
1919
+ # parent
1920
+ pid, status = os.wait()
1921
+ if os.WEXITSTATUS(status) != XERBLA_OK:
1922
+ pytest.skip('Numpy xerbla not linked in.')
1923
+
1924
+
1925
+ @pytest.mark.skipif(IS_WASM, reason="Cannot start subprocess")
1926
+ @pytest.mark.slow
1927
+ def test_sdot_bug_8577():
1928
+ # Regression test that loading certain other libraries does not
1929
+ # result to wrong results in float32 linear algebra.
1930
+ #
1931
+ # There's a bug gh-8577 on OSX that can trigger this, and perhaps
1932
+ # there are also other situations in which it occurs.
1933
+ #
1934
+ # Do the check in a separate process.
1935
+
1936
+ bad_libs = ['PyQt5.QtWidgets', 'IPython']
1937
+
1938
+ template = textwrap.dedent("""
1939
+ import sys
1940
+ {before}
1941
+ try:
1942
+ import {bad_lib}
1943
+ except ImportError:
1944
+ sys.exit(0)
1945
+ {after}
1946
+ x = np.ones(2, dtype=np.float32)
1947
+ sys.exit(0 if np.allclose(x.dot(x), 2.0) else 1)
1948
+ """)
1949
+
1950
+ for bad_lib in bad_libs:
1951
+ code = template.format(before="import numpy as np", after="",
1952
+ bad_lib=bad_lib)
1953
+ subprocess.check_call([sys.executable, "-c", code])
1954
+
1955
+ # Swapped import order
1956
+ code = template.format(after="import numpy as np", before="",
1957
+ bad_lib=bad_lib)
1958
+ subprocess.check_call([sys.executable, "-c", code])
1959
+
1960
+
1961
+ class TestMultiDot:
1962
+
1963
+ def test_basic_function_with_three_arguments(self):
1964
+ # multi_dot with three arguments uses a fast hand coded algorithm to
1965
+ # determine the optimal order. Therefore test it separately.
1966
+ A = np.random.random((6, 2))
1967
+ B = np.random.random((2, 6))
1968
+ C = np.random.random((6, 2))
1969
+
1970
+ assert_almost_equal(multi_dot([A, B, C]), A.dot(B).dot(C))
1971
+ assert_almost_equal(multi_dot([A, B, C]), np.dot(A, np.dot(B, C)))
1972
+
1973
+ def test_basic_function_with_two_arguments(self):
1974
+ # separate code path with two arguments
1975
+ A = np.random.random((6, 2))
1976
+ B = np.random.random((2, 6))
1977
+
1978
+ assert_almost_equal(multi_dot([A, B]), A.dot(B))
1979
+ assert_almost_equal(multi_dot([A, B]), np.dot(A, B))
1980
+
1981
+ def test_basic_function_with_dynamic_programming_optimization(self):
1982
+ # multi_dot with four or more arguments uses the dynamic programming
1983
+ # optimization and therefore deserve a separate
1984
+ A = np.random.random((6, 2))
1985
+ B = np.random.random((2, 6))
1986
+ C = np.random.random((6, 2))
1987
+ D = np.random.random((2, 1))
1988
+ assert_almost_equal(multi_dot([A, B, C, D]), A.dot(B).dot(C).dot(D))
1989
+
1990
+ def test_vector_as_first_argument(self):
1991
+ # The first argument can be 1-D
1992
+ A1d = np.random.random(2) # 1-D
1993
+ B = np.random.random((2, 6))
1994
+ C = np.random.random((6, 2))
1995
+ D = np.random.random((2, 2))
1996
+
1997
+ # the result should be 1-D
1998
+ assert_equal(multi_dot([A1d, B, C, D]).shape, (2,))
1999
+
2000
+ def test_vector_as_last_argument(self):
2001
+ # The last argument can be 1-D
2002
+ A = np.random.random((6, 2))
2003
+ B = np.random.random((2, 6))
2004
+ C = np.random.random((6, 2))
2005
+ D1d = np.random.random(2) # 1-D
2006
+
2007
+ # the result should be 1-D
2008
+ assert_equal(multi_dot([A, B, C, D1d]).shape, (6,))
2009
+
2010
+ def test_vector_as_first_and_last_argument(self):
2011
+ # The first and last arguments can be 1-D
2012
+ A1d = np.random.random(2) # 1-D
2013
+ B = np.random.random((2, 6))
2014
+ C = np.random.random((6, 2))
2015
+ D1d = np.random.random(2) # 1-D
2016
+
2017
+ # the result should be a scalar
2018
+ assert_equal(multi_dot([A1d, B, C, D1d]).shape, ())
2019
+
2020
+ def test_three_arguments_and_out(self):
2021
+ # multi_dot with three arguments uses a fast hand coded algorithm to
2022
+ # determine the optimal order. Therefore test it separately.
2023
+ A = np.random.random((6, 2))
2024
+ B = np.random.random((2, 6))
2025
+ C = np.random.random((6, 2))
2026
+
2027
+ out = np.zeros((6, 2))
2028
+ ret = multi_dot([A, B, C], out=out)
2029
+ assert out is ret
2030
+ assert_almost_equal(out, A.dot(B).dot(C))
2031
+ assert_almost_equal(out, np.dot(A, np.dot(B, C)))
2032
+
2033
+ def test_two_arguments_and_out(self):
2034
+ # separate code path with two arguments
2035
+ A = np.random.random((6, 2))
2036
+ B = np.random.random((2, 6))
2037
+ out = np.zeros((6, 6))
2038
+ ret = multi_dot([A, B], out=out)
2039
+ assert out is ret
2040
+ assert_almost_equal(out, A.dot(B))
2041
+ assert_almost_equal(out, np.dot(A, B))
2042
+
2043
+ def test_dynamic_programming_optimization_and_out(self):
2044
+ # multi_dot with four or more arguments uses the dynamic programming
2045
+ # optimization and therefore deserve a separate test
2046
+ A = np.random.random((6, 2))
2047
+ B = np.random.random((2, 6))
2048
+ C = np.random.random((6, 2))
2049
+ D = np.random.random((2, 1))
2050
+ out = np.zeros((6, 1))
2051
+ ret = multi_dot([A, B, C, D], out=out)
2052
+ assert out is ret
2053
+ assert_almost_equal(out, A.dot(B).dot(C).dot(D))
2054
+
2055
+ def test_dynamic_programming_logic(self):
2056
+ # Test for the dynamic programming part
2057
+ # This test is directly taken from Cormen page 376.
2058
+ arrays = [np.random.random((30, 35)),
2059
+ np.random.random((35, 15)),
2060
+ np.random.random((15, 5)),
2061
+ np.random.random((5, 10)),
2062
+ np.random.random((10, 20)),
2063
+ np.random.random((20, 25))]
2064
+ m_expected = np.array([[0., 15750., 7875., 9375., 11875., 15125.],
2065
+ [0., 0., 2625., 4375., 7125., 10500.],
2066
+ [0., 0., 0., 750., 2500., 5375.],
2067
+ [0., 0., 0., 0., 1000., 3500.],
2068
+ [0., 0., 0., 0., 0., 5000.],
2069
+ [0., 0., 0., 0., 0., 0.]])
2070
+ s_expected = np.array([[0, 1, 1, 3, 3, 3],
2071
+ [0, 0, 2, 3, 3, 3],
2072
+ [0, 0, 0, 3, 3, 3],
2073
+ [0, 0, 0, 0, 4, 5],
2074
+ [0, 0, 0, 0, 0, 5],
2075
+ [0, 0, 0, 0, 0, 0]], dtype=int)
2076
+ s_expected -= 1 # Cormen uses 1-based index, python does not.
2077
+
2078
+ s, m = _multi_dot_matrix_chain_order(arrays, return_costs=True)
2079
+
2080
+ # Only the upper triangular part (without the diagonal) is interesting.
2081
+ assert_almost_equal(np.triu(s[:-1, 1:]),
2082
+ np.triu(s_expected[:-1, 1:]))
2083
+ assert_almost_equal(np.triu(m), np.triu(m_expected))
2084
+
2085
+ def test_too_few_input_arrays(self):
2086
+ assert_raises(ValueError, multi_dot, [])
2087
+ assert_raises(ValueError, multi_dot, [np.random.random((3, 3))])
2088
+
2089
+
2090
+ class TestTensorinv:
2091
+
2092
+ @pytest.mark.parametrize("arr, ind", [
2093
+ (np.ones((4, 6, 8, 2)), 2),
2094
+ (np.ones((3, 3, 2)), 1),
2095
+ ])
2096
+ def test_non_square_handling(self, arr, ind):
2097
+ with assert_raises(LinAlgError):
2098
+ linalg.tensorinv(arr, ind=ind)
2099
+
2100
+ @pytest.mark.parametrize("shape, ind", [
2101
+ # examples from docstring
2102
+ ((4, 6, 8, 3), 2),
2103
+ ((24, 8, 3), 1),
2104
+ ])
2105
+ def test_tensorinv_shape(self, shape, ind):
2106
+ a = np.eye(24)
2107
+ a.shape = shape
2108
+ ainv = linalg.tensorinv(a=a, ind=ind)
2109
+ expected = a.shape[ind:] + a.shape[:ind]
2110
+ actual = ainv.shape
2111
+ assert_equal(actual, expected)
2112
+
2113
+ @pytest.mark.parametrize("ind", [
2114
+ 0, -2,
2115
+ ])
2116
+ def test_tensorinv_ind_limit(self, ind):
2117
+ a = np.eye(24)
2118
+ a.shape = (4, 6, 8, 3)
2119
+ with assert_raises(ValueError):
2120
+ linalg.tensorinv(a=a, ind=ind)
2121
+
2122
+ def test_tensorinv_result(self):
2123
+ # mimic a docstring example
2124
+ a = np.eye(24)
2125
+ a.shape = (24, 8, 3)
2126
+ ainv = linalg.tensorinv(a, ind=1)
2127
+ b = np.ones(24)
2128
+ assert_allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
2129
+
2130
+
2131
+ class TestTensorsolve:
2132
+
2133
+ @pytest.mark.parametrize("a, axes", [
2134
+ (np.ones((4, 6, 8, 2)), None),
2135
+ (np.ones((3, 3, 2)), (0, 2)),
2136
+ ])
2137
+ def test_non_square_handling(self, a, axes):
2138
+ with assert_raises(LinAlgError):
2139
+ b = np.ones(a.shape[:2])
2140
+ linalg.tensorsolve(a, b, axes=axes)
2141
+
2142
+ @pytest.mark.parametrize("shape",
2143
+ [(2, 3, 6), (3, 4, 4, 3), (0, 3, 3, 0)],
2144
+ )
2145
+ def test_tensorsolve_result(self, shape):
2146
+ a = np.random.randn(*shape)
2147
+ b = np.ones(a.shape[:2])
2148
+ x = np.linalg.tensorsolve(a, b)
2149
+ assert_allclose(np.tensordot(a, x, axes=len(x.shape)), b)
2150
+
2151
+
2152
+ def test_unsupported_commontype():
2153
+ # linalg gracefully handles unsupported type
2154
+ arr = np.array([[1, -2], [2, 5]], dtype='float16')
2155
+ with assert_raises_regex(TypeError, "unsupported in linalg"):
2156
+ linalg.cholesky(arr)
2157
+
2158
+
2159
+ #@pytest.mark.slow
2160
+ #@pytest.mark.xfail(not HAS_LAPACK64, run=False,
2161
+ # reason="Numpy not compiled with 64-bit BLAS/LAPACK")
2162
+ #@requires_memory(free_bytes=16e9)
2163
+ @pytest.mark.skip(reason="Bad memory reports lead to OOM in ci testing")
2164
+ def test_blas64_dot():
2165
+ n = 2**32
2166
+ a = np.zeros([1, n], dtype=np.float32)
2167
+ b = np.ones([1, 1], dtype=np.float32)
2168
+ a[0,-1] = 1
2169
+ c = np.dot(b, a)
2170
+ assert_equal(c[0,-1], 1)
2171
+
2172
+
2173
+ @pytest.mark.xfail(not HAS_LAPACK64,
2174
+ reason="Numpy not compiled with 64-bit BLAS/LAPACK")
2175
+ def test_blas64_geqrf_lwork_smoketest():
2176
+ # Smoke test LAPACK geqrf lwork call with 64-bit integers
2177
+ dtype = np.float64
2178
+ lapack_routine = np.linalg.lapack_lite.dgeqrf
2179
+
2180
+ m = 2**32 + 1
2181
+ n = 2**32 + 1
2182
+ lda = m
2183
+
2184
+ # Dummy arrays, not referenced by the lapack routine, so don't
2185
+ # need to be of the right size
2186
+ a = np.zeros([1, 1], dtype=dtype)
2187
+ work = np.zeros([1], dtype=dtype)
2188
+ tau = np.zeros([1], dtype=dtype)
2189
+
2190
+ # Size query
2191
+ results = lapack_routine(m, n, a, lda, tau, work, -1, 0)
2192
+ assert_equal(results['info'], 0)
2193
+ assert_equal(results['m'], m)
2194
+ assert_equal(results['n'], m)
2195
+
2196
+ # Should result to an integer of a reasonable size
2197
+ lwork = int(work.item())
2198
+ assert_(2**32 < lwork < 2**42)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/linalg/tests/test_regression.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Test functions for linalg module
2
+ """
3
+ import warnings
4
+
5
+ import numpy as np
6
+ from numpy import linalg, arange, float64, array, dot, transpose
7
+ from numpy.testing import (
8
+ assert_, assert_raises, assert_equal, assert_array_equal,
9
+ assert_array_almost_equal, assert_array_less
10
+ )
11
+
12
+
13
+ class TestRegression:
14
+
15
+ def test_eig_build(self):
16
+ # Ticket #652
17
+ rva = array([1.03221168e+02 + 0.j,
18
+ -1.91843603e+01 + 0.j,
19
+ -6.04004526e-01 + 15.84422474j,
20
+ -6.04004526e-01 - 15.84422474j,
21
+ -1.13692929e+01 + 0.j,
22
+ -6.57612485e-01 + 10.41755503j,
23
+ -6.57612485e-01 - 10.41755503j,
24
+ 1.82126812e+01 + 0.j,
25
+ 1.06011014e+01 + 0.j,
26
+ 7.80732773e+00 + 0.j,
27
+ -7.65390898e-01 + 0.j,
28
+ 1.51971555e-15 + 0.j,
29
+ -1.51308713e-15 + 0.j])
30
+ a = arange(13 * 13, dtype=float64)
31
+ a.shape = (13, 13)
32
+ a = a % 17
33
+ va, ve = linalg.eig(a)
34
+ va.sort()
35
+ rva.sort()
36
+ assert_array_almost_equal(va, rva)
37
+
38
+ def test_eigh_build(self):
39
+ # Ticket 662.
40
+ rvals = [68.60568999, 89.57756725, 106.67185574]
41
+
42
+ cov = array([[77.70273908, 3.51489954, 15.64602427],
43
+ [3.51489954, 88.97013878, -1.07431931],
44
+ [15.64602427, -1.07431931, 98.18223512]])
45
+
46
+ vals, vecs = linalg.eigh(cov)
47
+ assert_array_almost_equal(vals, rvals)
48
+
49
+ def test_svd_build(self):
50
+ # Ticket 627.
51
+ a = array([[0., 1.], [1., 1.], [2., 1.], [3., 1.]])
52
+ m, n = a.shape
53
+ u, s, vh = linalg.svd(a)
54
+
55
+ b = dot(transpose(u[:, n:]), a)
56
+
57
+ assert_array_almost_equal(b, np.zeros((2, 2)))
58
+
59
+ def test_norm_vector_badarg(self):
60
+ # Regression for #786: Frobenius norm for vectors raises
61
+ # ValueError.
62
+ assert_raises(ValueError, linalg.norm, array([1., 2., 3.]), 'fro')
63
+
64
+ def test_lapack_endian(self):
65
+ # For bug #1482
66
+ a = array([[5.7998084, -2.1825367],
67
+ [-2.1825367, 9.85910595]], dtype='>f8')
68
+ b = array(a, dtype='<f8')
69
+
70
+ ap = linalg.cholesky(a)
71
+ bp = linalg.cholesky(b)
72
+ assert_array_equal(ap, bp)
73
+
74
+ def test_large_svd_32bit(self):
75
+ # See gh-4442, 64bit would require very large/slow matrices.
76
+ x = np.eye(1000, 66)
77
+ np.linalg.svd(x)
78
+
79
+ def test_svd_no_uv(self):
80
+ # gh-4733
81
+ for shape in (3, 4), (4, 4), (4, 3):
82
+ for t in float, complex:
83
+ a = np.ones(shape, dtype=t)
84
+ w = linalg.svd(a, compute_uv=False)
85
+ c = np.count_nonzero(np.absolute(w) > 0.5)
86
+ assert_equal(c, 1)
87
+ assert_equal(np.linalg.matrix_rank(a), 1)
88
+ assert_array_less(1, np.linalg.norm(a, ord=2))
89
+
90
+ def test_norm_object_array(self):
91
+ # gh-7575
92
+ testvector = np.array([np.array([0, 1]), 0, 0], dtype=object)
93
+
94
+ norm = linalg.norm(testvector)
95
+ assert_array_equal(norm, [0, 1])
96
+ assert_(norm.dtype == np.dtype('float64'))
97
+
98
+ norm = linalg.norm(testvector, ord=1)
99
+ assert_array_equal(norm, [0, 1])
100
+ assert_(norm.dtype != np.dtype('float64'))
101
+
102
+ norm = linalg.norm(testvector, ord=2)
103
+ assert_array_equal(norm, [0, 1])
104
+ assert_(norm.dtype == np.dtype('float64'))
105
+
106
+ assert_raises(ValueError, linalg.norm, testvector, ord='fro')
107
+ assert_raises(ValueError, linalg.norm, testvector, ord='nuc')
108
+ assert_raises(ValueError, linalg.norm, testvector, ord=np.inf)
109
+ assert_raises(ValueError, linalg.norm, testvector, ord=-np.inf)
110
+ assert_raises(ValueError, linalg.norm, testvector, ord=0)
111
+ assert_raises(ValueError, linalg.norm, testvector, ord=-1)
112
+ assert_raises(ValueError, linalg.norm, testvector, ord=-2)
113
+
114
+ testmatrix = np.array([[np.array([0, 1]), 0, 0],
115
+ [0, 0, 0]], dtype=object)
116
+
117
+ norm = linalg.norm(testmatrix)
118
+ assert_array_equal(norm, [0, 1])
119
+ assert_(norm.dtype == np.dtype('float64'))
120
+
121
+ norm = linalg.norm(testmatrix, ord='fro')
122
+ assert_array_equal(norm, [0, 1])
123
+ assert_(norm.dtype == np.dtype('float64'))
124
+
125
+ assert_raises(TypeError, linalg.norm, testmatrix, ord='nuc')
126
+ assert_raises(ValueError, linalg.norm, testmatrix, ord=np.inf)
127
+ assert_raises(ValueError, linalg.norm, testmatrix, ord=-np.inf)
128
+ assert_raises(ValueError, linalg.norm, testmatrix, ord=0)
129
+ assert_raises(ValueError, linalg.norm, testmatrix, ord=1)
130
+ assert_raises(ValueError, linalg.norm, testmatrix, ord=-1)
131
+ assert_raises(TypeError, linalg.norm, testmatrix, ord=2)
132
+ assert_raises(TypeError, linalg.norm, testmatrix, ord=-2)
133
+ assert_raises(ValueError, linalg.norm, testmatrix, ord=3)
134
+
135
+ def test_lstsq_complex_larger_rhs(self):
136
+ # gh-9891
137
+ size = 20
138
+ n_rhs = 70
139
+ G = np.random.randn(size, size) + 1j * np.random.randn(size, size)
140
+ u = np.random.randn(size, n_rhs) + 1j * np.random.randn(size, n_rhs)
141
+ b = G.dot(u)
142
+ # This should work without segmentation fault.
143
+ u_lstsq, res, rank, sv = linalg.lstsq(G, b, rcond=None)
144
+ # check results just in case
145
+ assert_array_almost_equal(u_lstsq, u)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/arithmetic.cpython-312.pyc ADDED
Binary file (12 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/arrayterator.cpython-312.pyc ADDED
Binary file (1.07 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/bitwise_ops.cpython-312.pyc ADDED
Binary file (2.44 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/einsumfunc.cpython-312.pyc ADDED
Binary file (2.32 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/lib_utils.cpython-312.pyc ADDED
Binary file (1.27 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/literal.cpython-312.pyc ADDED
Binary file (2.53 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/multiarray.cpython-312.pyc ADDED
Binary file (3.51 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/numeric.cpython-312.pyc ADDED
Binary file (4.41 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/simple_py3.cpython-312.pyc ADDED
Binary file (356 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/ufuncs.cpython-312.pyc ADDED
Binary file (1.27 kB). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/numpy/typing/tests/data/pass/__pycache__/warnings_and_errors.cpython-312.pyc ADDED
Binary file (556 Bytes). View file
 
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/_convolution_double_backward.h ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // @generated by torchgen/gen.py from Function.h
5
+
6
+ #include <ATen/Context.h>
7
+ #include <ATen/DeviceGuard.h>
8
+ #include <ATen/TensorUtils.h>
9
+ #include <ATen/TracerMode.h>
10
+ #include <ATen/core/Generator.h>
11
+ #include <ATen/core/Reduction.h>
12
+ #include <ATen/core/Tensor.h>
13
+ #include <c10/core/Scalar.h>
14
+ #include <c10/core/Storage.h>
15
+ #include <c10/core/TensorOptions.h>
16
+ #include <c10/util/Deprecated.h>
17
+ #include <optional>
18
+ #include <string_view>
19
+
20
+
21
+
22
+ #include <ATen/ops/_convolution_double_backward_ops.h>
23
+
24
+ namespace at {
25
+
26
+
27
+ // aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
28
+ inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _convolution_double_backward(const ::std::optional<at::Tensor> & ggI, const ::std::optional<at::Tensor> & ggW, const ::std::optional<at::Tensor> & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask) {
29
+ return at::_ops::_convolution_double_backward::call(ggI, ggW, ggb, gO, weight, self, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask);
30
+ }
31
+ namespace symint {
32
+ template <typename T, typename = std::enable_if_t<std::is_same_v<T, int64_t>>>
33
+ ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _convolution_double_backward(const ::std::optional<at::Tensor> & ggI, const ::std::optional<at::Tensor> & ggW, const ::std::optional<at::Tensor> & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array<bool,3> output_mask) {
34
+ return at::_ops::_convolution_double_backward::call(ggI, ggW, ggb, gO, weight, self, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask);
35
+ }
36
+ }
37
+
38
+ // aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
39
+ inline ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _convolution_double_backward_symint(const ::std::optional<at::Tensor> & ggI, const ::std::optional<at::Tensor> & ggW, const ::std::optional<at::Tensor> & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask) {
40
+ return at::_ops::_convolution_double_backward::call(ggI, ggW, ggb, gO, weight, self, stride, padding, dilation, transposed, output_padding, groups, output_mask);
41
+ }
42
+ namespace symint {
43
+ template <typename T, typename = std::enable_if_t<std::is_same_v<T, c10::SymInt>>>
44
+ ::std::tuple<at::Tensor,at::Tensor,at::Tensor> _convolution_double_backward(const ::std::optional<at::Tensor> & ggI, const ::std::optional<at::Tensor> & ggW, const ::std::optional<at::Tensor> & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array<bool,3> output_mask) {
45
+ return at::_ops::_convolution_double_backward::call(ggI, ggW, ggb, gO, weight, self, stride, padding, dilation, transposed, output_padding, groups, output_mask);
46
+ }
47
+ }
48
+
49
+ }
50
+
51
+ #else
52
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
53
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/cudnn_convolution_relu_ops.h ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // @generated by torchgen/gen.py from Operator.h
5
+
6
+ #include <string_view>
7
+ #include <tuple>
8
+ #include <vector>
9
+
10
+ // Forward declarations of any types needed in the operator signatures.
11
+ // We can't directly include these classes because it will cause circular include dependencies.
12
+ // This file is included by TensorBody.h, which defines the Tensor class.
13
+ #include <ATen/core/ATen_fwd.h>
14
+
15
+ namespace at {
16
+ namespace _ops {
17
+
18
+
19
+ struct TORCH_API cudnn_convolution_relu {
20
+ using schema = at::Tensor (const at::Tensor &, const at::Tensor &, const ::std::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymInt);
21
+ using ptr_schema = schema*;
22
+ // See Note [static constexpr char* members for windows NVCC]
23
+ static constexpr const char* name = "aten::cudnn_convolution_relu";
24
+ static constexpr const char* overload_name = "";
25
+ static constexpr const char* schema_str = "cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor";
26
+ static at::Tensor call(const at::Tensor & self, const at::Tensor & weight, const ::std::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups);
27
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups);
28
+ };
29
+
30
+ struct TORCH_API cudnn_convolution_relu_out {
31
+ using schema = at::Tensor & (const at::Tensor &, const at::Tensor &, const ::std::optional<at::Tensor> &, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymIntArrayRef, c10::SymInt, at::Tensor &);
32
+ using ptr_schema = schema*;
33
+ // See Note [static constexpr char* members for windows NVCC]
34
+ static constexpr const char* name = "aten::cudnn_convolution_relu";
35
+ static constexpr const char* overload_name = "out";
36
+ static constexpr const char* schema_str = "cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)";
37
+ static at::Tensor & call(const at::Tensor & self, const at::Tensor & weight, const ::std::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out);
38
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional<at::Tensor> & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out);
39
+ };
40
+
41
+ }} // namespace at::_ops
42
+
43
+ #else
44
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
45
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d_ops.h ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // @generated by torchgen/gen.py from Operator.h
5
+
6
+ #include <string_view>
7
+ #include <tuple>
8
+ #include <vector>
9
+
10
+ // Forward declarations of any types needed in the operator signatures.
11
+ // We can't directly include these classes because it will cause circular include dependencies.
12
+ // This file is included by TensorBody.h, which defines the Tensor class.
13
+ #include <ATen/core/ATen_fwd.h>
14
+
15
+ namespace at {
16
+ namespace _ops {
17
+
18
+
19
+ struct TORCH_API mkldnn_adaptive_avg_pool2d {
20
+ using schema = at::Tensor (const at::Tensor &, at::IntArrayRef);
21
+ using ptr_schema = schema*;
22
+ // See Note [static constexpr char* members for windows NVCC]
23
+ static constexpr const char* name = "aten::mkldnn_adaptive_avg_pool2d";
24
+ static constexpr const char* overload_name = "";
25
+ static constexpr const char* schema_str = "mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor";
26
+ static at::Tensor call(const at::Tensor & self, at::IntArrayRef output_size);
27
+ static at::Tensor redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size);
28
+ };
29
+
30
+ struct TORCH_API mkldnn_adaptive_avg_pool2d_out {
31
+ using schema = at::Tensor & (const at::Tensor &, at::IntArrayRef, at::Tensor &);
32
+ using ptr_schema = schema*;
33
+ // See Note [static constexpr char* members for windows NVCC]
34
+ static constexpr const char* name = "aten::mkldnn_adaptive_avg_pool2d";
35
+ static constexpr const char* overload_name = "out";
36
+ static constexpr const char* schema_str = "mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!)";
37
+ static at::Tensor & call(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out);
38
+ static at::Tensor & redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out);
39
+ };
40
+
41
+ }} // namespace at::_ops
42
+
43
+ #else
44
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
45
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/silu_meta.h ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+
4
+ // @generated by torchgen/gen.py from NativeMetaFunction.h
5
+
6
+ #include <c10/core/Scalar.h>
7
+ #include <c10/core/Storage.h>
8
+ #include <c10/core/TensorOptions.h>
9
+ #include <c10/util/Deprecated.h>
10
+ #include <optional>
11
+ #include <c10/core/QScheme.h>
12
+ #include <ATen/core/Reduction.h>
13
+ #include <ATen/TensorIterator.h>
14
+ #include <ATen/TensorMeta.h>
15
+ #include <tuple>
16
+ #include <vector>
17
+
18
+ namespace at {
19
+ namespace meta {
20
+
21
+ struct TORCH_API structured_silu : public TensorIteratorBase {
22
+
23
+
24
+ void meta(const at::Tensor & self);
25
+ };
26
+
27
+ } // namespace native
28
+ } // namespace at
29
+
30
+ #else
31
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
32
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/slow_conv_transpose2d_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
4
+
5
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
6
+
7
+ // The only #includes we need are for custom classes that have defaults in the C++ API
8
+ #include <c10/core/MemoryFormat.h>
9
+ #include <c10/core/Scalar.h>
10
+ #include <ATen/core/Reduction.h>
11
+
12
+ // Forward declarations of any types needed in the operator signatures.
13
+ // We can't directly include these classes because it will cause circular include dependencies.
14
+ // This file is included by TensorBody.h, which defines the Tensor class.
15
+ #include <ATen/core/ATen_fwd.h>
16
+
17
+ namespace at {
18
+
19
+ namespace compositeexplicitautogradnonfunctional {
20
+
21
+ TORCH_API at::Tensor slow_conv_transpose2d(const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional<at::Tensor> & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1);
22
+ TORCH_API at::Tensor slow_conv_transpose2d_symint(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional<at::Tensor> & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1));
23
+
24
+ } // namespace compositeexplicitautogradnonfunctional
25
+ } // namespace at
26
+
27
+ #else
28
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
29
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/ATen/ops/upsample_bilinear2d_compositeexplicitautogradnonfunctional_dispatch.h ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ #pragma once
3
+ // @generated by torchgen/gen.py from DispatchKeyFunction.h
4
+
5
+ // NB: The implementing C++ file is RegisterDispatchKey.cpp
6
+
7
+ // The only #includes we need are for custom classes that have defaults in the C++ API
8
+ #include <c10/core/MemoryFormat.h>
9
+ #include <c10/core/Scalar.h>
10
+ #include <ATen/core/Reduction.h>
11
+
12
+ // Forward declarations of any types needed in the operator signatures.
13
+ // We can't directly include these classes because it will cause circular include dependencies.
14
+ // This file is included by TensorBody.h, which defines the Tensor class.
15
+ #include <ATen/core/ATen_fwd.h>
16
+
17
+ namespace at {
18
+
19
+ namespace compositeexplicitautogradnonfunctional {
20
+
21
+ TORCH_API at::Tensor upsample_bilinear2d(const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional<double> scales_h=::std::nullopt, ::std::optional<double> scales_w=::std::nullopt);
22
+ TORCH_API at::Tensor upsample_bilinear2d_symint(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional<double> scales_h=::std::nullopt, ::std::optional<double> scales_w=::std::nullopt);
23
+
24
+ } // namespace compositeexplicitautogradnonfunctional
25
+ } // namespace at
26
+
27
+ #else
28
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
29
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/any.h ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Protocol Buffers - Google's data interchange format
3
+ // Copyright 2008 Google Inc. All rights reserved.
4
+ // https://developers.google.com/protocol-buffers/
5
+ //
6
+ // Redistribution and use in source and binary forms, with or without
7
+ // modification, are permitted provided that the following conditions are
8
+ // met:
9
+ //
10
+ // * Redistributions of source code must retain the above copyright
11
+ // notice, this list of conditions and the following disclaimer.
12
+ // * Redistributions in binary form must reproduce the above
13
+ // copyright notice, this list of conditions and the following disclaimer
14
+ // in the documentation and/or other materials provided with the
15
+ // distribution.
16
+ // * Neither the name of Google Inc. nor the names of its
17
+ // contributors may be used to endorse or promote products derived from
18
+ // this software without specific prior written permission.
19
+ //
20
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21
+ // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22
+ // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23
+ // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24
+ // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25
+ // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26
+ // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27
+ // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28
+ // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+
32
+ #ifndef GOOGLE_PROTOBUF_ANY_H__
33
+ #define GOOGLE_PROTOBUF_ANY_H__
34
+
35
+ #include <string>
36
+
37
+ #include <google/protobuf/stubs/common.h>
38
+ #include <google/protobuf/arenastring.h>
39
+ #include <google/protobuf/message_lite.h>
40
+
41
+ #include <google/protobuf/port_def.inc>
42
+
43
+ namespace google {
44
+ namespace protobuf {
45
+
46
+ class FieldDescriptor;
47
+ class Message;
48
+
49
+ namespace internal {
50
+
51
+ extern const char kAnyFullTypeName[]; // "google.protobuf.Any".
52
+ extern const char kTypeGoogleApisComPrefix[]; // "type.googleapis.com/".
53
+ extern const char kTypeGoogleProdComPrefix[]; // "type.googleprod.com/".
54
+
55
+ std::string GetTypeUrl(StringPiece message_name,
56
+ StringPiece type_url_prefix);
57
+
58
+ // Helper class used to implement google::protobuf::Any.
59
+ class PROTOBUF_EXPORT AnyMetadata {
60
+ typedef ArenaStringPtr UrlType;
61
+ typedef ArenaStringPtr ValueType;
62
+ public:
63
+ // AnyMetadata does not take ownership of "type_url" and "value".
64
+ AnyMetadata(UrlType* type_url, ValueType* value);
65
+
66
+ // Packs a message using the default type URL prefix: "type.googleapis.com".
67
+ // The resulted type URL will be "type.googleapis.com/<message_full_name>".
68
+ template <typename T>
69
+ void PackFrom(const T& message) {
70
+ InternalPackFrom(message, kTypeGoogleApisComPrefix, T::FullMessageName());
71
+ }
72
+
73
+ void PackFrom(const Message& message);
74
+
75
+ // Packs a message using the given type URL prefix. The type URL will be
76
+ // constructed by concatenating the message type's full name to the prefix
77
+ // with an optional "/" separator if the prefix doesn't already end with "/".
78
+ // For example, both PackFrom(message, "type.googleapis.com") and
79
+ // PackFrom(message, "type.googleapis.com/") yield the same result type
80
+ // URL: "type.googleapis.com/<message_full_name>".
81
+ template <typename T>
82
+ void PackFrom(const T& message, StringPiece type_url_prefix) {
83
+ InternalPackFrom(message, type_url_prefix, T::FullMessageName());
84
+ }
85
+
86
+ void PackFrom(const Message& message, const std::string& type_url_prefix);
87
+
88
+ // Unpacks the payload into the given message. Returns false if the message's
89
+ // type doesn't match the type specified in the type URL (i.e., the full
90
+ // name after the last "/" of the type URL doesn't match the message's actual
91
+ // full name) or parsing the payload has failed.
92
+ template <typename T>
93
+ bool UnpackTo(T* message) const {
94
+ return InternalUnpackTo(T::FullMessageName(), message);
95
+ }
96
+
97
+ bool UnpackTo(Message* message) const;
98
+
99
+ // Checks whether the type specified in the type URL matches the given type.
100
+ // A type is considered matching if its full name matches the full name after
101
+ // the last "/" in the type URL.
102
+ template <typename T>
103
+ bool Is() const {
104
+ return InternalIs(T::FullMessageName());
105
+ }
106
+
107
+ private:
108
+ void InternalPackFrom(const MessageLite& message,
109
+ StringPiece type_url_prefix,
110
+ StringPiece type_name);
111
+ bool InternalUnpackTo(StringPiece type_name,
112
+ MessageLite* message) const;
113
+ bool InternalIs(StringPiece type_name) const;
114
+
115
+ UrlType* type_url_;
116
+ ValueType* value_;
117
+
118
+ GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(AnyMetadata);
119
+ };
120
+
121
+ // Get the proto type name from Any::type_url value. For example, passing
122
+ // "type.googleapis.com/rpc.QueryOrigin" will return "rpc.QueryOrigin" in
123
+ // *full_type_name. Returns false if the type_url does not have a "/"
124
+ // in the type url separating the full type name.
125
+ //
126
+ // NOTE: this function is available publicly as:
127
+ // google::protobuf::Any() // static method on the generated message type.
128
+ bool ParseAnyTypeUrl(const std::string& type_url, std::string* full_type_name);
129
+
130
+ // Get the proto type name and prefix from Any::type_url value. For example,
131
+ // passing "type.googleapis.com/rpc.QueryOrigin" will return
132
+ // "type.googleapis.com/" in *url_prefix and "rpc.QueryOrigin" in
133
+ // *full_type_name. Returns false if the type_url does not have a "/" in the
134
+ // type url separating the full type name.
135
+ bool ParseAnyTypeUrl(const std::string& type_url, std::string* url_prefix,
136
+ std::string* full_type_name);
137
+
138
+ // See if message is of type google.protobuf.Any, if so, return the descriptors
139
+ // for "type_url" and "value" fields.
140
+ bool GetAnyFieldDescriptors(const Message& message,
141
+ const FieldDescriptor** type_url_field,
142
+ const FieldDescriptor** value_field);
143
+
144
+ } // namespace internal
145
+ } // namespace protobuf
146
+ } // namespace google
147
+
148
+ #include <google/protobuf/port_undef.inc>
149
+
150
+ #endif // GOOGLE_PROTOBUF_ANY_H__
151
+
152
+ #else
153
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
154
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/extension_set_inl.h ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Protocol Buffers - Google's data interchange format
3
+ // Copyright 2008 Google Inc. All rights reserved.
4
+ // https://developers.google.com/protocol-buffers/
5
+ //
6
+ // Redistribution and use in source and binary forms, with or without
7
+ // modification, are permitted provided that the following conditions are
8
+ // met:
9
+ //
10
+ // * Redistributions of source code must retain the above copyright
11
+ // notice, this list of conditions and the following disclaimer.
12
+ // * Redistributions in binary form must reproduce the above
13
+ // copyright notice, this list of conditions and the following disclaimer
14
+ // in the documentation and/or other materials provided with the
15
+ // distribution.
16
+ // * Neither the name of Google Inc. nor the names of its
17
+ // contributors may be used to endorse or promote products derived from
18
+ // this software without specific prior written permission.
19
+ //
20
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21
+ // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22
+ // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23
+ // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24
+ // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25
+ // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26
+ // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27
+ // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28
+ // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+
32
+ #ifndef GOOGLE_PROTOBUF_EXTENSION_SET_INL_H__
33
+ #define GOOGLE_PROTOBUF_EXTENSION_SET_INL_H__
34
+
35
+ #include <google/protobuf/parse_context.h>
36
+ #include <google/protobuf/extension_set.h>
37
+ #include <google/protobuf/metadata_lite.h>
38
+
39
+ namespace google {
40
+ namespace protobuf {
41
+ namespace internal {
42
+
43
+ template <typename T>
44
+ const char* ExtensionSet::ParseFieldWithExtensionInfo(
45
+ int number, bool was_packed_on_wire, const ExtensionInfo& extension,
46
+ InternalMetadata* metadata, const char* ptr, internal::ParseContext* ctx) {
47
+ if (was_packed_on_wire) {
48
+ switch (extension.type) {
49
+ #define HANDLE_TYPE(UPPERCASE, CPP_CAMELCASE) \
50
+ case WireFormatLite::TYPE_##UPPERCASE: \
51
+ return internal::Packed##CPP_CAMELCASE##Parser( \
52
+ MutableRawRepeatedField(number, extension.type, extension.is_packed, \
53
+ extension.descriptor), \
54
+ ptr, ctx);
55
+ HANDLE_TYPE(INT32, Int32);
56
+ HANDLE_TYPE(INT64, Int64);
57
+ HANDLE_TYPE(UINT32, UInt32);
58
+ HANDLE_TYPE(UINT64, UInt64);
59
+ HANDLE_TYPE(SINT32, SInt32);
60
+ HANDLE_TYPE(SINT64, SInt64);
61
+ HANDLE_TYPE(FIXED32, Fixed32);
62
+ HANDLE_TYPE(FIXED64, Fixed64);
63
+ HANDLE_TYPE(SFIXED32, SFixed32);
64
+ HANDLE_TYPE(SFIXED64, SFixed64);
65
+ HANDLE_TYPE(FLOAT, Float);
66
+ HANDLE_TYPE(DOUBLE, Double);
67
+ HANDLE_TYPE(BOOL, Bool);
68
+ #undef HANDLE_TYPE
69
+
70
+ case WireFormatLite::TYPE_ENUM:
71
+ return internal::PackedEnumParserArg<T>(
72
+ MutableRawRepeatedField(number, extension.type, extension.is_packed,
73
+ extension.descriptor),
74
+ ptr, ctx, extension.enum_validity_check.func,
75
+ extension.enum_validity_check.arg, metadata, number);
76
+ case WireFormatLite::TYPE_STRING:
77
+ case WireFormatLite::TYPE_BYTES:
78
+ case WireFormatLite::TYPE_GROUP:
79
+ case WireFormatLite::TYPE_MESSAGE:
80
+ GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed.";
81
+ break;
82
+ }
83
+ } else {
84
+ switch (extension.type) {
85
+ #define HANDLE_VARINT_TYPE(UPPERCASE, CPP_CAMELCASE) \
86
+ case WireFormatLite::TYPE_##UPPERCASE: { \
87
+ uint64 value; \
88
+ ptr = VarintParse(ptr, &value); \
89
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); \
90
+ if (extension.is_repeated) { \
91
+ Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \
92
+ extension.is_packed, value, extension.descriptor); \
93
+ } else { \
94
+ Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \
95
+ extension.descriptor); \
96
+ } \
97
+ } break
98
+
99
+ HANDLE_VARINT_TYPE(INT32, Int32);
100
+ HANDLE_VARINT_TYPE(INT64, Int64);
101
+ HANDLE_VARINT_TYPE(UINT32, UInt32);
102
+ HANDLE_VARINT_TYPE(UINT64, UInt64);
103
+ HANDLE_VARINT_TYPE(BOOL, Bool);
104
+ #undef HANDLE_VARINT_TYPE
105
+ #define HANDLE_SVARINT_TYPE(UPPERCASE, CPP_CAMELCASE, SIZE) \
106
+ case WireFormatLite::TYPE_##UPPERCASE: { \
107
+ uint64 val; \
108
+ ptr = VarintParse(ptr, &val); \
109
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr); \
110
+ auto value = WireFormatLite::ZigZagDecode##SIZE(val); \
111
+ if (extension.is_repeated) { \
112
+ Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \
113
+ extension.is_packed, value, extension.descriptor); \
114
+ } else { \
115
+ Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \
116
+ extension.descriptor); \
117
+ } \
118
+ } break
119
+
120
+ HANDLE_SVARINT_TYPE(SINT32, Int32, 32);
121
+ HANDLE_SVARINT_TYPE(SINT64, Int64, 64);
122
+ #undef HANDLE_SVARINT_TYPE
123
+ #define HANDLE_FIXED_TYPE(UPPERCASE, CPP_CAMELCASE, CPPTYPE) \
124
+ case WireFormatLite::TYPE_##UPPERCASE: { \
125
+ auto value = UnalignedLoad<CPPTYPE>(ptr); \
126
+ ptr += sizeof(CPPTYPE); \
127
+ if (extension.is_repeated) { \
128
+ Add##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, \
129
+ extension.is_packed, value, extension.descriptor); \
130
+ } else { \
131
+ Set##CPP_CAMELCASE(number, WireFormatLite::TYPE_##UPPERCASE, value, \
132
+ extension.descriptor); \
133
+ } \
134
+ } break
135
+
136
+ HANDLE_FIXED_TYPE(FIXED32, UInt32, uint32);
137
+ HANDLE_FIXED_TYPE(FIXED64, UInt64, uint64);
138
+ HANDLE_FIXED_TYPE(SFIXED32, Int32, int32);
139
+ HANDLE_FIXED_TYPE(SFIXED64, Int64, int64);
140
+ HANDLE_FIXED_TYPE(FLOAT, Float, float);
141
+ HANDLE_FIXED_TYPE(DOUBLE, Double, double);
142
+ #undef HANDLE_FIXED_TYPE
143
+
144
+ case WireFormatLite::TYPE_ENUM: {
145
+ uint64 val;
146
+ ptr = VarintParse(ptr, &val);
147
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
148
+ int value = val;
149
+
150
+ if (!extension.enum_validity_check.func(
151
+ extension.enum_validity_check.arg, value)) {
152
+ WriteVarint(number, val, metadata->mutable_unknown_fields<T>());
153
+ } else if (extension.is_repeated) {
154
+ AddEnum(number, WireFormatLite::TYPE_ENUM, extension.is_packed, value,
155
+ extension.descriptor);
156
+ } else {
157
+ SetEnum(number, WireFormatLite::TYPE_ENUM, value,
158
+ extension.descriptor);
159
+ }
160
+ break;
161
+ }
162
+
163
+ case WireFormatLite::TYPE_BYTES:
164
+ case WireFormatLite::TYPE_STRING: {
165
+ std::string* value =
166
+ extension.is_repeated
167
+ ? AddString(number, WireFormatLite::TYPE_STRING,
168
+ extension.descriptor)
169
+ : MutableString(number, WireFormatLite::TYPE_STRING,
170
+ extension.descriptor);
171
+ int size = ReadSize(&ptr);
172
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
173
+ return ctx->ReadString(ptr, size, value);
174
+ }
175
+
176
+ case WireFormatLite::TYPE_GROUP: {
177
+ MessageLite* value =
178
+ extension.is_repeated
179
+ ? AddMessage(number, WireFormatLite::TYPE_GROUP,
180
+ *extension.message_info.prototype,
181
+ extension.descriptor)
182
+ : MutableMessage(number, WireFormatLite::TYPE_GROUP,
183
+ *extension.message_info.prototype,
184
+ extension.descriptor);
185
+ uint32 tag = (number << 3) + WireFormatLite::WIRETYPE_START_GROUP;
186
+ return ctx->ParseGroup(value, ptr, tag);
187
+ }
188
+
189
+ case WireFormatLite::TYPE_MESSAGE: {
190
+ MessageLite* value =
191
+ extension.is_repeated
192
+ ? AddMessage(number, WireFormatLite::TYPE_MESSAGE,
193
+ *extension.message_info.prototype,
194
+ extension.descriptor)
195
+ : MutableMessage(number, WireFormatLite::TYPE_MESSAGE,
196
+ *extension.message_info.prototype,
197
+ extension.descriptor);
198
+ return ctx->ParseMessage(value, ptr);
199
+ }
200
+ }
201
+ }
202
+ return ptr;
203
+ }
204
+
205
+ template <typename Msg, typename T>
206
+ const char* ExtensionSet::ParseMessageSetItemTmpl(
207
+ const char* ptr, const Msg* containing_type,
208
+ internal::InternalMetadata* metadata, internal::ParseContext* ctx) {
209
+ std::string payload;
210
+ uint32 type_id = 0;
211
+ bool payload_read = false;
212
+ while (!ctx->Done(&ptr)) {
213
+ uint32 tag = static_cast<uint8>(*ptr++);
214
+ if (tag == WireFormatLite::kMessageSetTypeIdTag) {
215
+ uint64 tmp;
216
+ ptr = ParseBigVarint(ptr, &tmp);
217
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
218
+ type_id = tmp;
219
+ if (payload_read) {
220
+ ExtensionInfo extension;
221
+ bool was_packed_on_wire;
222
+ if (!FindExtension(2, type_id, containing_type, ctx, &extension,
223
+ &was_packed_on_wire)) {
224
+ WriteLengthDelimited(type_id, payload,
225
+ metadata->mutable_unknown_fields<T>());
226
+ } else {
227
+ MessageLite* value =
228
+ extension.is_repeated
229
+ ? AddMessage(type_id, WireFormatLite::TYPE_MESSAGE,
230
+ *extension.message_info.prototype,
231
+ extension.descriptor)
232
+ : MutableMessage(type_id, WireFormatLite::TYPE_MESSAGE,
233
+ *extension.message_info.prototype,
234
+ extension.descriptor);
235
+
236
+ const char* p;
237
+ // We can't use regular parse from string as we have to track
238
+ // proper recursion depth and descriptor pools.
239
+ ParseContext tmp_ctx(ctx->depth(), false, &p, payload);
240
+ tmp_ctx.data().pool = ctx->data().pool;
241
+ tmp_ctx.data().factory = ctx->data().factory;
242
+ GOOGLE_PROTOBUF_PARSER_ASSERT(value->_InternalParse(p, &tmp_ctx) &&
243
+ tmp_ctx.EndedAtLimit());
244
+ }
245
+ type_id = 0;
246
+ }
247
+ } else if (tag == WireFormatLite::kMessageSetMessageTag) {
248
+ if (type_id != 0) {
249
+ ptr = ParseFieldMaybeLazily(static_cast<uint64>(type_id) * 8 + 2, ptr,
250
+ containing_type, metadata, ctx);
251
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr != nullptr);
252
+ type_id = 0;
253
+ } else {
254
+ int32 size = ReadSize(&ptr);
255
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
256
+ ptr = ctx->ReadString(ptr, size, &payload);
257
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
258
+ payload_read = true;
259
+ }
260
+ } else {
261
+ ptr = ReadTag(ptr - 1, &tag);
262
+ if (tag == 0 || (tag & 7) == 4) {
263
+ ctx->SetLastTag(tag);
264
+ return ptr;
265
+ }
266
+ ptr = ParseField(tag, ptr, containing_type, metadata, ctx);
267
+ GOOGLE_PROTOBUF_PARSER_ASSERT(ptr);
268
+ }
269
+ }
270
+ return ptr;
271
+ }
272
+
273
+ } // namespace internal
274
+ } // namespace protobuf
275
+ } // namespace google
276
+
277
+ #endif // GOOGLE_PROTOBUF_EXTENSION_SET_INL_H__
278
+
279
+ #else
280
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
281
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/field_mask.pb.h ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ // source: google/protobuf/field_mask.proto
4
+
5
+ #ifndef GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ffield_5fmask_2eproto
6
+ #define GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ffield_5fmask_2eproto
7
+
8
+ #include <limits>
9
+ #include <string>
10
+
11
+ #include <google/protobuf/port_def.inc>
12
+ #if PROTOBUF_VERSION < 3013000
13
+ #error This file was generated by a newer version of protoc which is
14
+ #error incompatible with your Protocol Buffer headers. Please update
15
+ #error your headers.
16
+ #endif
17
+ #if 3013000 < PROTOBUF_MIN_PROTOC_VERSION
18
+ #error This file was generated by an older version of protoc which is
19
+ #error incompatible with your Protocol Buffer headers. Please
20
+ #error regenerate this file with a newer version of protoc.
21
+ #endif
22
+
23
+ #include <google/protobuf/port_undef.inc>
24
+ #include <google/protobuf/io/coded_stream.h>
25
+ #include <google/protobuf/arena.h>
26
+ #include <google/protobuf/arenastring.h>
27
+ #include <google/protobuf/generated_message_table_driven.h>
28
+ #include <google/protobuf/generated_message_util.h>
29
+ #include <google/protobuf/inlined_string_field.h>
30
+ #include <google/protobuf/metadata_lite.h>
31
+ #include <google/protobuf/generated_message_reflection.h>
32
+ #include <google/protobuf/message.h>
33
+ #include <google/protobuf/repeated_field.h> // IWYU pragma: export
34
+ #include <google/protobuf/extension_set.h> // IWYU pragma: export
35
+ #include <google/protobuf/unknown_field_set.h>
36
+ // @@protoc_insertion_point(includes)
37
+ #include <google/protobuf/port_def.inc>
38
+ #define PROTOBUF_INTERNAL_EXPORT_google_2fprotobuf_2ffield_5fmask_2eproto PROTOBUF_EXPORT
39
+ PROTOBUF_NAMESPACE_OPEN
40
+ namespace internal {
41
+ class AnyMetadata;
42
+ } // namespace internal
43
+ PROTOBUF_NAMESPACE_CLOSE
44
+
45
+ // Internal implementation detail -- do not use these members.
46
+ struct PROTOBUF_EXPORT TableStruct_google_2fprotobuf_2ffield_5fmask_2eproto {
47
+ static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTableField entries[]
48
+ PROTOBUF_SECTION_VARIABLE(protodesc_cold);
49
+ static const ::PROTOBUF_NAMESPACE_ID::internal::AuxiliaryParseTableField aux[]
50
+ PROTOBUF_SECTION_VARIABLE(protodesc_cold);
51
+ static const ::PROTOBUF_NAMESPACE_ID::internal::ParseTable schema[1]
52
+ PROTOBUF_SECTION_VARIABLE(protodesc_cold);
53
+ static const ::PROTOBUF_NAMESPACE_ID::internal::FieldMetadata field_metadata[];
54
+ static const ::PROTOBUF_NAMESPACE_ID::internal::SerializationTable serialization_table[];
55
+ static const ::PROTOBUF_NAMESPACE_ID::uint32 offsets[];
56
+ };
57
+ extern PROTOBUF_EXPORT const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_google_2fprotobuf_2ffield_5fmask_2eproto;
58
+ PROTOBUF_NAMESPACE_OPEN
59
+ class FieldMask;
60
+ class FieldMaskDefaultTypeInternal;
61
+ PROTOBUF_EXPORT extern FieldMaskDefaultTypeInternal _FieldMask_default_instance_;
62
+ PROTOBUF_NAMESPACE_CLOSE
63
+ PROTOBUF_NAMESPACE_OPEN
64
+ template<> PROTOBUF_EXPORT PROTOBUF_NAMESPACE_ID::FieldMask* Arena::CreateMaybeMessage<PROTOBUF_NAMESPACE_ID::FieldMask>(Arena*);
65
+ PROTOBUF_NAMESPACE_CLOSE
66
+ PROTOBUF_NAMESPACE_OPEN
67
+
68
+ // ===================================================================
69
+
70
+ class PROTOBUF_EXPORT FieldMask PROTOBUF_FINAL :
71
+ public ::PROTOBUF_NAMESPACE_ID::Message /* @@protoc_insertion_point(class_definition:google.protobuf.FieldMask) */ {
72
+ public:
73
+ inline FieldMask() : FieldMask(nullptr) {}
74
+ virtual ~FieldMask();
75
+
76
+ FieldMask(const FieldMask& from);
77
+ FieldMask(FieldMask&& from) noexcept
78
+ : FieldMask() {
79
+ *this = ::std::move(from);
80
+ }
81
+
82
+ inline FieldMask& operator=(const FieldMask& from) {
83
+ CopyFrom(from);
84
+ return *this;
85
+ }
86
+ inline FieldMask& operator=(FieldMask&& from) noexcept {
87
+ if (GetArena() == from.GetArena()) {
88
+ if (this != &from) InternalSwap(&from);
89
+ } else {
90
+ CopyFrom(from);
91
+ }
92
+ return *this;
93
+ }
94
+
95
+ static const ::PROTOBUF_NAMESPACE_ID::Descriptor* descriptor() {
96
+ return GetDescriptor();
97
+ }
98
+ static const ::PROTOBUF_NAMESPACE_ID::Descriptor* GetDescriptor() {
99
+ return GetMetadataStatic().descriptor;
100
+ }
101
+ static const ::PROTOBUF_NAMESPACE_ID::Reflection* GetReflection() {
102
+ return GetMetadataStatic().reflection;
103
+ }
104
+ static const FieldMask& default_instance();
105
+
106
+ static void InitAsDefaultInstance(); // FOR INTERNAL USE ONLY
107
+ static inline const FieldMask* internal_default_instance() {
108
+ return reinterpret_cast<const FieldMask*>(
109
+ &_FieldMask_default_instance_);
110
+ }
111
+ static constexpr int kIndexInFileMessages =
112
+ 0;
113
+
114
+ friend void swap(FieldMask& a, FieldMask& b) {
115
+ a.Swap(&b);
116
+ }
117
+ inline void Swap(FieldMask* other) {
118
+ if (other == this) return;
119
+ if (GetArena() == other->GetArena()) {
120
+ InternalSwap(other);
121
+ } else {
122
+ ::PROTOBUF_NAMESPACE_ID::internal::GenericSwap(this, other);
123
+ }
124
+ }
125
+ void UnsafeArenaSwap(FieldMask* other) {
126
+ if (other == this) return;
127
+ GOOGLE_DCHECK(GetArena() == other->GetArena());
128
+ InternalSwap(other);
129
+ }
130
+
131
+ // implements Message ----------------------------------------------
132
+
133
+ inline FieldMask* New() const final {
134
+ return CreateMaybeMessage<FieldMask>(nullptr);
135
+ }
136
+
137
+ FieldMask* New(::PROTOBUF_NAMESPACE_ID::Arena* arena) const final {
138
+ return CreateMaybeMessage<FieldMask>(arena);
139
+ }
140
+ void CopyFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
141
+ void MergeFrom(const ::PROTOBUF_NAMESPACE_ID::Message& from) final;
142
+ void CopyFrom(const FieldMask& from);
143
+ void MergeFrom(const FieldMask& from);
144
+ PROTOBUF_ATTRIBUTE_REINITIALIZES void Clear() final;
145
+ bool IsInitialized() const final;
146
+
147
+ size_t ByteSizeLong() const final;
148
+ const char* _InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) final;
149
+ ::PROTOBUF_NAMESPACE_ID::uint8* _InternalSerialize(
150
+ ::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const final;
151
+ int GetCachedSize() const final { return _cached_size_.Get(); }
152
+
153
+ private:
154
+ inline void SharedCtor();
155
+ inline void SharedDtor();
156
+ void SetCachedSize(int size) const final;
157
+ void InternalSwap(FieldMask* other);
158
+ friend class ::PROTOBUF_NAMESPACE_ID::internal::AnyMetadata;
159
+ static ::PROTOBUF_NAMESPACE_ID::StringPiece FullMessageName() {
160
+ return "google.protobuf.FieldMask";
161
+ }
162
+ protected:
163
+ explicit FieldMask(::PROTOBUF_NAMESPACE_ID::Arena* arena);
164
+ private:
165
+ static void ArenaDtor(void* object);
166
+ inline void RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena* arena);
167
+ public:
168
+
169
+ ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadata() const final;
170
+ private:
171
+ static ::PROTOBUF_NAMESPACE_ID::Metadata GetMetadataStatic() {
172
+ ::PROTOBUF_NAMESPACE_ID::internal::AssignDescriptors(&::descriptor_table_google_2fprotobuf_2ffield_5fmask_2eproto);
173
+ return ::descriptor_table_google_2fprotobuf_2ffield_5fmask_2eproto.file_level_metadata[kIndexInFileMessages];
174
+ }
175
+
176
+ public:
177
+
178
+ // nested types ----------------------------------------------------
179
+
180
+ // accessors -------------------------------------------------------
181
+
182
+ enum : int {
183
+ kPathsFieldNumber = 1,
184
+ };
185
+ // repeated string paths = 1;
186
+ int paths_size() const;
187
+ private:
188
+ int _internal_paths_size() const;
189
+ public:
190
+ void clear_paths();
191
+ const std::string& paths(int index) const;
192
+ std::string* mutable_paths(int index);
193
+ void set_paths(int index, const std::string& value);
194
+ void set_paths(int index, std::string&& value);
195
+ void set_paths(int index, const char* value);
196
+ void set_paths(int index, const char* value, size_t size);
197
+ std::string* add_paths();
198
+ void add_paths(const std::string& value);
199
+ void add_paths(std::string&& value);
200
+ void add_paths(const char* value);
201
+ void add_paths(const char* value, size_t size);
202
+ const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>& paths() const;
203
+ ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>* mutable_paths();
204
+ private:
205
+ const std::string& _internal_paths(int index) const;
206
+ std::string* _internal_add_paths();
207
+ public:
208
+
209
+ // @@protoc_insertion_point(class_scope:google.protobuf.FieldMask)
210
+ private:
211
+ class _Internal;
212
+
213
+ template <typename T> friend class ::PROTOBUF_NAMESPACE_ID::Arena::InternalHelper;
214
+ typedef void InternalArenaConstructable_;
215
+ typedef void DestructorSkippable_;
216
+ ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string> paths_;
217
+ mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
218
+ friend struct ::TableStruct_google_2fprotobuf_2ffield_5fmask_2eproto;
219
+ };
220
+ // ===================================================================
221
+
222
+
223
+ // ===================================================================
224
+
225
+ #ifdef __GNUC__
226
+ #pragma GCC diagnostic push
227
+ #pragma GCC diagnostic ignored "-Wstrict-aliasing"
228
+ #endif // __GNUC__
229
+ // FieldMask
230
+
231
+ // repeated string paths = 1;
232
+ inline int FieldMask::_internal_paths_size() const {
233
+ return paths_.size();
234
+ }
235
+ inline int FieldMask::paths_size() const {
236
+ return _internal_paths_size();
237
+ }
238
+ inline void FieldMask::clear_paths() {
239
+ paths_.Clear();
240
+ }
241
+ inline std::string* FieldMask::add_paths() {
242
+ // @@protoc_insertion_point(field_add_mutable:google.protobuf.FieldMask.paths)
243
+ return _internal_add_paths();
244
+ }
245
+ inline const std::string& FieldMask::_internal_paths(int index) const {
246
+ return paths_.Get(index);
247
+ }
248
+ inline const std::string& FieldMask::paths(int index) const {
249
+ // @@protoc_insertion_point(field_get:google.protobuf.FieldMask.paths)
250
+ return _internal_paths(index);
251
+ }
252
+ inline std::string* FieldMask::mutable_paths(int index) {
253
+ // @@protoc_insertion_point(field_mutable:google.protobuf.FieldMask.paths)
254
+ return paths_.Mutable(index);
255
+ }
256
+ inline void FieldMask::set_paths(int index, const std::string& value) {
257
+ // @@protoc_insertion_point(field_set:google.protobuf.FieldMask.paths)
258
+ paths_.Mutable(index)->assign(value);
259
+ }
260
+ inline void FieldMask::set_paths(int index, std::string&& value) {
261
+ // @@protoc_insertion_point(field_set:google.protobuf.FieldMask.paths)
262
+ paths_.Mutable(index)->assign(std::move(value));
263
+ }
264
+ inline void FieldMask::set_paths(int index, const char* value) {
265
+ GOOGLE_DCHECK(value != nullptr);
266
+ paths_.Mutable(index)->assign(value);
267
+ // @@protoc_insertion_point(field_set_char:google.protobuf.FieldMask.paths)
268
+ }
269
+ inline void FieldMask::set_paths(int index, const char* value, size_t size) {
270
+ paths_.Mutable(index)->assign(
271
+ reinterpret_cast<const char*>(value), size);
272
+ // @@protoc_insertion_point(field_set_pointer:google.protobuf.FieldMask.paths)
273
+ }
274
+ inline std::string* FieldMask::_internal_add_paths() {
275
+ return paths_.Add();
276
+ }
277
+ inline void FieldMask::add_paths(const std::string& value) {
278
+ paths_.Add()->assign(value);
279
+ // @@protoc_insertion_point(field_add:google.protobuf.FieldMask.paths)
280
+ }
281
+ inline void FieldMask::add_paths(std::string&& value) {
282
+ paths_.Add(std::move(value));
283
+ // @@protoc_insertion_point(field_add:google.protobuf.FieldMask.paths)
284
+ }
285
+ inline void FieldMask::add_paths(const char* value) {
286
+ GOOGLE_DCHECK(value != nullptr);
287
+ paths_.Add()->assign(value);
288
+ // @@protoc_insertion_point(field_add_char:google.protobuf.FieldMask.paths)
289
+ }
290
+ inline void FieldMask::add_paths(const char* value, size_t size) {
291
+ paths_.Add()->assign(reinterpret_cast<const char*>(value), size);
292
+ // @@protoc_insertion_point(field_add_pointer:google.protobuf.FieldMask.paths)
293
+ }
294
+ inline const ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>&
295
+ FieldMask::paths() const {
296
+ // @@protoc_insertion_point(field_list:google.protobuf.FieldMask.paths)
297
+ return paths_;
298
+ }
299
+ inline ::PROTOBUF_NAMESPACE_ID::RepeatedPtrField<std::string>*
300
+ FieldMask::mutable_paths() {
301
+ // @@protoc_insertion_point(field_mutable_list:google.protobuf.FieldMask.paths)
302
+ return &paths_;
303
+ }
304
+
305
+ #ifdef __GNUC__
306
+ #pragma GCC diagnostic pop
307
+ #endif // __GNUC__
308
+
309
+ // @@protoc_insertion_point(namespace_scope)
310
+
311
+ PROTOBUF_NAMESPACE_CLOSE
312
+
313
+ // @@protoc_insertion_point(global_scope)
314
+
315
+ #include <google/protobuf/port_undef.inc>
316
+ #endif // GOOGLE_PROTOBUF_INCLUDED_GOOGLE_PROTOBUF_INCLUDED_google_2fprotobuf_2ffield_5fmask_2eproto
317
+
318
+ #else
319
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
320
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
Prism/LLaDA/LLaDA_Prism/.venv/lib/python3.12/site-packages/torch/include/google/protobuf/generated_message_table_driven.h ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
2
+ // Protocol Buffers - Google's data interchange format
3
+ // Copyright 2008 Google Inc. All rights reserved.
4
+ // https://developers.google.com/protocol-buffers/
5
+ //
6
+ // Redistribution and use in source and binary forms, with or without
7
+ // modification, are permitted provided that the following conditions are
8
+ // met:
9
+ //
10
+ // * Redistributions of source code must retain the above copyright
11
+ // notice, this list of conditions and the following disclaimer.
12
+ // * Redistributions in binary form must reproduce the above
13
+ // copyright notice, this list of conditions and the following disclaimer
14
+ // in the documentation and/or other materials provided with the
15
+ // distribution.
16
+ // * Neither the name of Google Inc. nor the names of its
17
+ // contributors may be used to endorse or promote products derived from
18
+ // this software without specific prior written permission.
19
+ //
20
+ // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21
+ // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22
+ // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23
+ // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24
+ // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25
+ // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26
+ // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27
+ // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28
+ // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
+
32
+ #ifndef GOOGLE_PROTOBUF_GENERATED_MESSAGE_TABLE_DRIVEN_H__
33
+ #define GOOGLE_PROTOBUF_GENERATED_MESSAGE_TABLE_DRIVEN_H__
34
+
35
+ #include <google/protobuf/map.h>
36
+ #include <google/protobuf/map_entry_lite.h>
37
+ #include <google/protobuf/map_field_lite.h>
38
+ #include <google/protobuf/message_lite.h>
39
+ #include <google/protobuf/wire_format_lite.h>
40
+
41
+ // We require C++11 and Clang to use constexpr for variables, as GCC 4.8
42
+ // requires constexpr to be consistent between declarations of variables
43
+ // unnecessarily (see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=58541).
44
+ // VS 2017 Update 3 also supports this usage of constexpr.
45
+ #if defined(__clang__) || (defined(_MSC_VER) && _MSC_VER >= 1911)
46
+ #define PROTOBUF_CONSTEXPR_VAR constexpr
47
+ #else // !__clang__
48
+ #define PROTOBUF_CONSTEXPR_VAR
49
+ #endif // !_clang
50
+
51
+ #ifdef SWIG
52
+ #error "You cannot SWIG proto headers"
53
+ #endif
54
+
55
+ #include <google/protobuf/port_def.inc>
56
+
57
+ namespace google {
58
+ namespace protobuf {
59
+ namespace internal {
60
+
61
+ // Processing-type masks.
62
+ static constexpr const unsigned char kOneofMask = 0x40;
63
+ static constexpr const unsigned char kRepeatedMask = 0x20;
64
+ // Mask for the raw type: either a WireFormatLite::FieldType or one of the
65
+ // ProcessingTypes below, without the oneof or repeated flag.
66
+ static constexpr const unsigned char kTypeMask = 0x1f;
67
+
68
+ // Wire type masks.
69
+ static constexpr const unsigned char kNotPackedMask = 0x10;
70
+ static constexpr const unsigned char kInvalidMask = 0x20;
71
+
72
+ enum ProcessingTypes {
73
+ TYPE_STRING_CORD = 19,
74
+ TYPE_STRING_STRING_PIECE = 20,
75
+ TYPE_BYTES_CORD = 21,
76
+ TYPE_BYTES_STRING_PIECE = 22,
77
+ TYPE_STRING_INLINED = 23,
78
+ TYPE_BYTES_INLINED = 24,
79
+ TYPE_MAP = 25,
80
+ };
81
+
82
+ static_assert(TYPE_MAP < kRepeatedMask, "Invalid enum");
83
+
84
+ struct PROTOBUF_EXPORT FieldMetadata {
85
+ uint32 offset; // offset of this field in the struct
86
+ uint32 tag; // field * 8 + wire_type
87
+ // byte offset * 8 + bit_offset;
88
+ // if the high bit is set then this is the byte offset of the oneof_case
89
+ // for this field.
90
+ uint32 has_offset;
91
+ uint32 type; // the type of this field.
92
+ const void* ptr; // auxiliary data
93
+
94
+ // From the serializer point of view each fundamental type can occur in
95
+ // 4 different ways. For simplicity we treat all combinations as a cartesion
96
+ // product although not all combinations are allowed.
97
+ enum FieldTypeClass {
98
+ kPresence,
99
+ kNoPresence,
100
+ kRepeated,
101
+ kPacked,
102
+ kOneOf,
103
+ kNumTypeClasses // must be last enum
104
+ };
105
+ // C++ protobuf has 20 fundamental types, were we added Cord and StringPiece
106
+ // and also distinquish the same types if they have different wire format.
107
+ enum {
108
+ kCordType = 19,
109
+ kStringPieceType = 20,
110
+ kInlinedType = 21,
111
+ kNumTypes = 21,
112
+ kSpecial = kNumTypes * kNumTypeClasses,
113
+ };
114
+
115
+ static int CalculateType(int fundamental_type, FieldTypeClass type_class);
116
+ };
117
+
118
+ // TODO(ckennelly): Add a static assertion to ensure that these masks do not
119
+ // conflict with wiretypes.
120
+
121
+ // ParseTableField is kept small to help simplify instructions for computing
122
+ // offsets, as we will always need this information to parse a field.
123
+ // Additional data, needed for some types, is stored in
124
+ // AuxiliaryParseTableField.
125
+ struct ParseTableField {
126
+ uint32 offset;
127
+ // The presence_index ordinarily represents a has_bit index, but for fields
128
+ // inside a oneof it represents the index in _oneof_case_.
129
+ uint32 presence_index;
130
+ unsigned char normal_wiretype;
131
+ unsigned char packed_wiretype;
132
+
133
+ // processing_type is given by:
134
+ // (FieldDescriptor->type() << 1) | FieldDescriptor->is_packed()
135
+ unsigned char processing_type;
136
+
137
+ unsigned char tag_size;
138
+ };
139
+
140
+ struct ParseTable;
141
+
142
+ union AuxiliaryParseTableField {
143
+ typedef bool (*EnumValidator)(int);
144
+
145
+ // Enums
146
+ struct enum_aux {
147
+ EnumValidator validator;
148
+ };
149
+ enum_aux enums;
150
+ // Group, messages
151
+ struct message_aux {
152
+ // ExplicitlyInitialized<T> -> T requires a reinterpret_cast, which prevents
153
+ // the tables from being constructed as a constexpr. We use void to avoid
154
+ // the cast.
155
+ const void* default_message_void;
156
+ const MessageLite* default_message() const {
157
+ return static_cast<const MessageLite*>(default_message_void);
158
+ }
159
+ };
160
+ message_aux messages;
161
+ // Strings
162
+ struct string_aux {
163
+ const void* default_ptr;
164
+ const char* field_name;
165
+ };
166
+ string_aux strings;
167
+
168
+ struct map_aux {
169
+ bool (*parse_map)(io::CodedInputStream*, void*);
170
+ };
171
+ map_aux maps;
172
+
173
+ AuxiliaryParseTableField() = default;
174
+ constexpr AuxiliaryParseTableField(AuxiliaryParseTableField::enum_aux e)
175
+ : enums(e) {}
176
+ constexpr AuxiliaryParseTableField(AuxiliaryParseTableField::message_aux m)
177
+ : messages(m) {}
178
+ constexpr AuxiliaryParseTableField(AuxiliaryParseTableField::string_aux s)
179
+ : strings(s) {}
180
+ constexpr AuxiliaryParseTableField(AuxiliaryParseTableField::map_aux m)
181
+ : maps(m) {}
182
+ };
183
+
184
+ struct ParseTable {
185
+ const ParseTableField* fields;
186
+ const AuxiliaryParseTableField* aux;
187
+ int max_field_number;
188
+ // TODO(ckennelly): Do something with this padding.
189
+
190
+ // TODO(ckennelly): Vet these for sign extension.
191
+ int64 has_bits_offset;
192
+ int64 oneof_case_offset;
193
+ int64 extension_offset;
194
+ int64 arena_offset;
195
+
196
+ // ExplicitlyInitialized<T> -> T requires a reinterpret_cast, which prevents
197
+ // the tables from being constructed as a constexpr. We use void to avoid
198
+ // the cast.
199
+ const void* default_instance_void;
200
+ const MessageLite* default_instance() const {
201
+ return static_cast<const MessageLite*>(default_instance_void);
202
+ }
203
+
204
+ bool unknown_field_set;
205
+ };
206
+
207
+ static_assert(sizeof(ParseTableField) <= 16, "ParseTableField is too large");
208
+ // The tables must be composed of POD components to ensure link-time
209
+ // initialization.
210
+ static_assert(std::is_pod<ParseTableField>::value, "");
211
+ static_assert(std::is_pod<AuxiliaryParseTableField>::value, "");
212
+ static_assert(std::is_pod<AuxiliaryParseTableField::enum_aux>::value, "");
213
+ static_assert(std::is_pod<AuxiliaryParseTableField::message_aux>::value, "");
214
+ static_assert(std::is_pod<AuxiliaryParseTableField::string_aux>::value, "");
215
+ static_assert(std::is_pod<ParseTable>::value, "");
216
+
217
+ // TODO(ckennelly): Consolidate these implementations into a single one, using
218
+ // dynamic dispatch to the appropriate unknown field handler.
219
+ bool MergePartialFromCodedStream(MessageLite* msg, const ParseTable& table,
220
+ io::CodedInputStream* input);
221
+ bool MergePartialFromCodedStreamLite(MessageLite* msg, const ParseTable& table,
222
+ io::CodedInputStream* input);
223
+
224
+ template <typename Entry>
225
+ bool ParseMap(io::CodedInputStream* input, void* map_field) {
226
+ typedef typename MapEntryToMapField<Entry>::MapFieldType MapFieldType;
227
+ typedef Map<typename Entry::EntryKeyType, typename Entry::EntryValueType>
228
+ MapType;
229
+ typedef typename Entry::template Parser<MapFieldType, MapType> ParserType;
230
+
231
+ ParserType parser(static_cast<MapFieldType*>(map_field));
232
+ return WireFormatLite::ReadMessageNoVirtual(input, &parser);
233
+ }
234
+
235
+ struct SerializationTable {
236
+ int num_fields;
237
+ const FieldMetadata* field_table;
238
+ };
239
+
240
+ PROTOBUF_EXPORT void SerializeInternal(const uint8* base,
241
+ const FieldMetadata* table,
242
+ int32 num_fields,
243
+ io::CodedOutputStream* output);
244
+
245
+ inline void TableSerialize(const MessageLite& msg,
246
+ const SerializationTable* table,
247
+ io::CodedOutputStream* output) {
248
+ const FieldMetadata* field_table = table->field_table;
249
+ int num_fields = table->num_fields - 1;
250
+ const uint8* base = reinterpret_cast<const uint8*>(&msg);
251
+ // TODO(gerbens) This skips the first test if we could use the fast
252
+ // array serialization path, we should make this
253
+ // int cached_size =
254
+ // *reinterpret_cast<const int32*>(base + field_table->offset);
255
+ // SerializeWithCachedSize(msg, field_table + 1, num_fields, cached_size, ...)
256
+ // But we keep conformance with the old way for now.
257
+ SerializeInternal(base, field_table + 1, num_fields, output);
258
+ }
259
+
260
+ uint8* SerializeInternalToArray(const uint8* base, const FieldMetadata* table,
261
+ int32 num_fields, bool is_deterministic,
262
+ uint8* buffer);
263
+
264
+ inline uint8* TableSerializeToArray(const MessageLite& msg,
265
+ const SerializationTable* table,
266
+ bool is_deterministic, uint8* buffer) {
267
+ const uint8* base = reinterpret_cast<const uint8*>(&msg);
268
+ const FieldMetadata* field_table = table->field_table + 1;
269
+ int num_fields = table->num_fields - 1;
270
+ return SerializeInternalToArray(base, field_table, num_fields,
271
+ is_deterministic, buffer);
272
+ }
273
+
274
+ template <typename T>
275
+ struct CompareHelper {
276
+ bool operator()(const T& a, const T& b) const { return a < b; }
277
+ };
278
+
279
+ template <>
280
+ struct CompareHelper<ArenaStringPtr> {
281
+ bool operator()(const ArenaStringPtr& a, const ArenaStringPtr& b) const {
282
+ return a.Get() < b.Get();
283
+ }
284
+ };
285
+
286
+ struct CompareMapKey {
287
+ template <typename T>
288
+ bool operator()(const MapEntryHelper<T>& a,
289
+ const MapEntryHelper<T>& b) const {
290
+ return Compare(a.key_, b.key_);
291
+ }
292
+ template <typename T>
293
+ bool Compare(const T& a, const T& b) const {
294
+ return CompareHelper<T>()(a, b);
295
+ }
296
+ };
297
+
298
+ template <typename MapFieldType, const SerializationTable* table>
299
+ void MapFieldSerializer(const uint8* base, uint32 offset, uint32 tag,
300
+ uint32 has_offset, io::CodedOutputStream* output) {
301
+ typedef MapEntryHelper<typename MapFieldType::EntryTypeTrait> Entry;
302
+ typedef typename MapFieldType::MapType::const_iterator Iter;
303
+
304
+ const MapFieldType& map_field =
305
+ *reinterpret_cast<const MapFieldType*>(base + offset);
306
+ const SerializationTable* t =
307
+ table +
308
+ has_offset; // has_offset is overloaded for maps to mean table offset
309
+ if (!output->IsSerializationDeterministic()) {
310
+ for (Iter it = map_field.GetMap().begin(); it != map_field.GetMap().end();
311
+ ++it) {
312
+ Entry map_entry(*it);
313
+ output->WriteVarint32(tag);
314
+ output->WriteVarint32(map_entry._cached_size_);
315
+ SerializeInternal(reinterpret_cast<const uint8*>(&map_entry),
316
+ t->field_table, t->num_fields, output);
317
+ }
318
+ } else {
319
+ std::vector<Entry> v;
320
+ for (Iter it = map_field.GetMap().begin(); it != map_field.GetMap().end();
321
+ ++it) {
322
+ v.push_back(Entry(*it));
323
+ }
324
+ std::sort(v.begin(), v.end(), CompareMapKey());
325
+ for (int i = 0; i < v.size(); i++) {
326
+ output->WriteVarint32(tag);
327
+ output->WriteVarint32(v[i]._cached_size_);
328
+ SerializeInternal(reinterpret_cast<const uint8*>(&v[i]), t->field_table,
329
+ t->num_fields, output);
330
+ }
331
+ }
332
+ }
333
+
334
+ } // namespace internal
335
+ } // namespace protobuf
336
+ } // namespace google
337
+
338
+ #include <google/protobuf/port_undef.inc>
339
+
340
+ #endif // GOOGLE_PROTOBUF_GENERATED_MESSAGE_TABLE_DRIVEN_H__
341
+
342
+ #else
343
+ #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
344
+ #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)