davidtran999 committed on
Commit
4ea4bb7
·
verified ·
1 Parent(s): fb65c6a

Upload backend/venv/lib/python3.10/site-packages/threadpoolctl.py with huggingface_hub

Browse files
backend/venv/lib/python3.10/site-packages/threadpoolctl.py ADDED
@@ -0,0 +1,1292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """threadpoolctl
2
+
3
+ This module provides utilities to introspect native libraries that rely on
4
+ thread pools (notably BLAS and OpenMP implementations) and dynamically set the
5
+ maximal number of threads they can use.
6
+ """
7
+
8
+ # License: BSD 3-Clause
9
+
10
+ # The code to introspect dynamically loaded libraries on POSIX systems is
11
+ # adapted from code by Intel developer @anton-malakhov available at
12
+ # https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation)
13
+ # and also published under the BSD 3-Clause license
14
+ import os
15
+ import re
16
+ import sys
17
+ import ctypes
18
+ import itertools
19
+ import textwrap
20
+ from typing import final
21
+ import warnings
22
+ from ctypes.util import find_library
23
+ from abc import ABC, abstractmethod
24
+ from functools import lru_cache
25
+ from contextlib import ContextDecorator
26
+
27
+ __version__ = "3.6.0"
28
+ __all__ = [
29
+ "threadpool_limits",
30
+ "threadpool_info",
31
+ "ThreadpoolController",
32
+ "LibController",
33
+ "register",
34
+ ]
35
+
36
+
37
# One can get runtime errors or even segfaults due to multiple OpenMP libraries
# loaded simultaneously, which can happen easily in Python when importing and
# using compiled extensions built with different compilers and therefore
# different OpenMP runtimes in the same program. In particular libiomp (used by
# Intel ICC) and libomp (used by clang/llvm) tend to crash. This can happen for
# instance when calling BLAS inside a prange. Setting the following environment
# variable allows multiple OpenMP libraries to be loaded. It should not degrade
# performance since we manually take care of potential over-subscription
# performance issues, in sections of the code where nested OpenMP loops can
# happen, by dynamically reconfiguring the inner OpenMP runtime to temporarily
# disable it while under the scope of the outer OpenMP parallel section.
#
# `setdefault` is used so that a value already present in the environment is
# left untouched.
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True")
49
+
50
# Structure to cast the info on dynamically loaded library. See
# https://linux.die.net/man/3/dl_iterate_phdr for more details.
# The integer widths are picked from the interpreter's word size: 64-bit wide
# (and 32-bit "half") when sys.maxsize exceeds 2**32, otherwise 32-bit wide
# (and 16-bit "half").
_SYSTEM_UINT = ctypes.c_uint64 if sys.maxsize > 2**32 else ctypes.c_uint32
_SYSTEM_UINT_HALF = ctypes.c_uint32 if sys.maxsize > 2**32 else ctypes.c_uint16


class _dl_phdr_info(ctypes.Structure):
    """ctypes layout used to read the entries reported by dl_iterate_phdr."""

    _fields_ = [
        ("dlpi_addr", _SYSTEM_UINT),  # Base address of object
        ("dlpi_name", ctypes.c_char_p),  # path to the library
        ("dlpi_phdr", ctypes.c_void_p),  # pointer on dlpi_headers
        ("dlpi_phnum", _SYSTEM_UINT_HALF),  # number of elements in dlpi_phdr
    ]
63
+
64
+
65
# The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows.
# When the `os` module does not expose it, fall back to ctypes' default mode.
try:
    _RTLD_NOLOAD = os.RTLD_NOLOAD
except AttributeError:
    _RTLD_NOLOAD = ctypes.DEFAULT_MODE
70
+
71
+
72
class LibController(ABC):
    """Abstract base class for the individual library controllers

    A library controller must expose the following class attributes:
        - user_api : str
            Usually the name of the library or generic specification the library
            implements, e.g. "blas" is a specification with different implementations.
        - internal_api : str
            Usually the name of the library or concrete implementation of some
            specification, e.g. "openblas" is an implementation of the "blas"
            specification.
        - filename_prefixes : tuple
            Possible prefixes of the shared library's filename that allow to
            identify the library. e.g. "libopenblas" for libopenblas.so.

    and implement the following methods: `get_num_threads`, `set_num_threads` and
    `get_version`.

    Threadpoolctl loops through all the loaded shared libraries and tries to match
    the filename of each library with the `filename_prefixes`. If a match is found, a
    controller is instantiated and a handler to the library is stored in the `dynlib`
    attribute as a `ctypes.CDLL` object. It can be used to access the necessary symbols
    of the shared library to implement the above methods.

    The following information will be exposed in the info dictionary:
      - user_api : standardized API, if any, or a copy of internal_api.
      - internal_api : implementation-specific API.
      - num_threads : the current thread limit.
      - prefix : prefix of the shared library's filename.
      - filepath : path to the loaded shared library.
      - version : version of the library (if available).

    In addition, each library controller may expose internal API specific entries. They
    must be set as attributes in the `set_additional_attributes` method.
    """

    @final
    def __init__(self, *, filepath=None, prefix=None, parent=None):
        """This is not meant to be overridden by subclasses.

        Stores `parent`, `prefix` and `filepath`, opens `filepath` as a
        `ctypes.CDLL` with the `_RTLD_NOLOAD` mode, then resolves the symbol
        affixes, the library version and the additional attributes, in that
        order.
        """
        self.parent = parent
        self.prefix = prefix
        self.filepath = filepath
        self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
        # The affixes must be known before `get_version` and
        # `set_additional_attributes` run, since those may call `_get_symbol`.
        self._symbol_prefix, self._symbol_suffix = self._find_affixes()
        self.version = self.get_version()
        self.set_additional_attributes()

    def info(self):
        """Return relevant info wrapped in a dict

        The dict holds `user_api`, `internal_api`, `num_threads` and every
        instance attribute except the internal handles listed in
        `hidden_attrs`.
        """
        hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix")
        return {
            "user_api": self.user_api,
            "internal_api": self.internal_api,
            "num_threads": self.num_threads,
            **{k: v for k, v in vars(self).items() if k not in hidden_attrs},
        }

    def set_additional_attributes(self):
        """Set additional attributes meant to be exposed in the info dict"""

    @property
    def num_threads(self):
        """Exposes the current thread limit as a dynamic property

        This is not meant to be used or overridden by subclasses.
        """
        return self.get_num_threads()

    @abstractmethod
    def get_num_threads(self):
        """Return the maximum number of threads available to use"""

    @abstractmethod
    def set_num_threads(self, num_threads):
        """Set the maximum number of threads to use"""

    @abstractmethod
    def get_version(self):
        """Return the version of the shared library"""

    def _find_affixes(self):
        """Return the affixes for the symbols of the shared library"""
        return "", ""

    def _get_symbol(self, name):
        """Return the symbol of the shared library accounting for the affixes

        Returns None when the affixed symbol is not an attribute of `dynlib`.
        """
        return getattr(
            self.dynlib, f"{self._symbol_prefix}{name}{self._symbol_suffix}", None
        )
161
+
162
+
163
class OpenBLASController(LibController):
    """Controller class for OpenBLAS"""

    user_api = "blas"
    internal_api = "openblas"
    filename_prefixes = ("libopenblas", "libblas", "libscipy_openblas")

    _symbol_prefixes = ("", "scipy_")
    _symbol_suffixes = ("", "64_", "_64")

    # All variations of "openblas_get_num_threads", accounting for the affixes
    check_symbols = tuple(
        f"{prefix}openblas_get_num_threads{suffix}"
        for prefix, suffix in itertools.product(_symbol_prefixes, _symbol_suffixes)
    )

    def _find_affixes(self):
        """Return the (prefix, suffix) pair used by the symbols of this library

        The first combination for which the affixed `openblas_get_num_threads`
        symbol exists in `dynlib` is returned. When none matches, fall back to
        empty affixes so that the caller can always unpack a pair.
        """
        for prefix, suffix in itertools.product(
            self._symbol_prefixes, self._symbol_suffixes
        ):
            if hasattr(self.dynlib, f"{prefix}openblas_get_num_threads{suffix}"):
                return prefix, suffix
        return "", ""

    def set_additional_attributes(self):
        """Expose the threading layer and the architecture in the info dict"""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the value of `openblas_get_num_threads`, or None if missing"""
        get_num_threads_func = self._get_symbol("openblas_get_num_threads")
        if get_num_threads_func is not None:
            return get_num_threads_func()
        return None

    def set_num_threads(self, num_threads):
        """Call `openblas_set_num_threads(num_threads)` if the symbol exists"""
        set_num_threads_func = self._get_symbol("openblas_set_num_threads")
        if set_num_threads_func is not None:
            return set_num_threads_func(num_threads)
        return None

    def get_version(self):
        """Return the OpenBLAS version parsed from `openblas_get_config`

        Returns None when the symbol is missing or when the config string does
        not start with "OpenBLAS" followed by a version token.
        """
        # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
        # did not expose its version before that.
        get_version_func = self._get_symbol("openblas_get_config")
        if get_version_func is None:
            return None

        get_version_func.restype = ctypes.c_char_p
        # Guard against a NULL or too-short config string before indexing.
        config = (get_version_func() or b"").split()
        if len(config) >= 2 and config[0] == b"OpenBLAS":
            return config[1].decode("utf-8")
        return None

    def _get_threading_layer(self):
        """Return the threading layer of OpenBLAS"""
        get_threading_layer_func = self._get_symbol("openblas_get_parallel")
        if get_threading_layer_func is None:
            return "unknown"

        threading_layer = get_threading_layer_func()
        if threading_layer == 2:
            return "openmp"
        elif threading_layer == 1:
            return "pthreads"
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by OpenBLAS"""
        get_architecture_func = self._get_symbol("openblas_get_corename")
        if get_architecture_func is not None:
            get_architecture_func.restype = ctypes.c_char_p
            return get_architecture_func().decode("utf-8")
        return None
233
+
234
+
235
class BLISController(LibController):
    """Controller class for BLIS"""

    user_api = "blas"
    internal_api = "blis"
    filename_prefixes = ("libblis", "libblas")
    check_symbols = (
        "bli_thread_get_num_threads",
        "bli_thread_set_num_threads",
        "bli_info_get_version_str",
        "bli_info_get_enable_openmp",
        "bli_info_get_enable_pthreads",
        "bli_arch_query_id",
        "bli_arch_string",
    )

    def set_additional_attributes(self):
        """Expose the threading layer and the architecture in the info dict"""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the current thread limit, or None if the symbol is missing"""
        get_func = getattr(self.dynlib, "bli_thread_get_num_threads", lambda: None)
        num_threads = get_func()
        # by default BLIS is single-threaded and get_num_threads
        # returns -1. We map it to 1 for consistency with other libraries.
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        """Call `bli_thread_set_num_threads(num_threads)` if the symbol exists"""
        set_func = getattr(
            self.dynlib, "bli_thread_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        """Return the string given by `bli_info_get_version_str`, or None"""
        get_version_ = getattr(self.dynlib, "bli_info_get_version_str", None)
        if get_version_ is None:
            return None

        get_version_.restype = ctypes.c_char_p
        return get_version_().decode("utf-8")

    def _get_threading_layer(self):
        """Return the threading layer of BLIS"""
        if getattr(self.dynlib, "bli_info_get_enable_openmp", lambda: False)():
            return "openmp"
        elif getattr(self.dynlib, "bli_info_get_enable_pthreads", lambda: False)():
            return "pthreads"
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by BLIS"""
        bli_arch_query_id = getattr(self.dynlib, "bli_arch_query_id", None)
        bli_arch_string = getattr(self.dynlib, "bli_arch_string", None)
        if bli_arch_query_id is None or bli_arch_string is None:
            return None

        # the true restype should be BLIS' arch_t (enum) but int should work
        # for us:
        bli_arch_query_id.restype = ctypes.c_int
        bli_arch_string.restype = ctypes.c_char_p
        return bli_arch_string(bli_arch_query_id()).decode("utf-8")
296
+
297
+
298
class FlexiBLASController(LibController):
    """Controller class for FlexiBLAS"""

    user_api = "blas"
    internal_api = "flexiblas"
    filename_prefixes = ("libflexiblas",)
    check_symbols = (
        "flexiblas_get_num_threads",
        "flexiblas_set_num_threads",
        "flexiblas_get_version",
        "flexiblas_list",
        "flexiblas_list_loaded",
        "flexiblas_current_backend",
    )

    @property
    def loaded_backends(self):
        """Names of the currently loaded backends, queried on each access"""
        return self._get_backend_list(loaded=True)

    @property
    def current_backend(self):
        """Name of the backend currently in use, queried on each access"""
        return self._get_current_backend()

    def info(self):
        """Return relevant info wrapped in a dict"""
        # We override the info method because the loaded and current backends
        # are dynamic properties
        exposed_attrs = super().info()
        exposed_attrs["loaded_backends"] = self.loaded_backends
        exposed_attrs["current_backend"] = self.current_backend

        return exposed_attrs

    def set_additional_attributes(self):
        """Expose the list of available backends in the info dict"""
        self.available_backends = self._get_backend_list(loaded=False)

    def get_num_threads(self):
        """Return the current thread limit, or None if the symbol is missing"""
        get_func = getattr(self.dynlib, "flexiblas_get_num_threads", lambda: None)
        num_threads = get_func()
        # Mirror the BLIS controller: a reported value of -1 is mapped to 1
        # for consistency with other libraries.
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        """Call `flexiblas_set_num_threads(num_threads)` if the symbol exists"""
        set_func = getattr(
            self.dynlib, "flexiblas_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        """Return the version as "major.minor.patch", or None if unavailable"""
        get_version_ = getattr(self.dynlib, "flexiblas_get_version", None)
        if get_version_ is None:
            return None

        # The three components are written by the library through pointers.
        major = ctypes.c_int()
        minor = ctypes.c_int()
        patch = ctypes.c_int()
        get_version_(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
        return f"{major.value}.{minor.value}.{patch.value}"

    def _get_backend_list(self, loaded=False):
        """Return the list of available backends for FlexiBLAS.

        If loaded is False, return the list of available backends from the FlexiBLAS
        configuration. If loaded is True, return the list of actually loaded backends.
        Return None when the corresponding symbol is missing from the library.
        """
        func_name = f"flexiblas_list{'_loaded' if loaded else ''}"
        get_backend_list_ = getattr(self.dynlib, func_name, None)
        if get_backend_list_ is None:
            return None

        # Called with a null buffer, the function is used here to obtain the
        # number of backends to iterate over.
        n_backends = get_backend_list_(None, 0, 0)

        backends = []
        for i in range(n_backends):
            backend_name = ctypes.create_string_buffer(1024)
            get_backend_list_(backend_name, 1024, i)
            name = backend_name.value.decode("utf-8")
            if name != "__FALLBACK__":
                # We don't know when to expect __FALLBACK__ but it is not a real
                # backend and does not show up when running flexiblas list.
                backends.append(name)
        return backends

    def _get_current_backend(self):
        """Return the backend of FlexiBLAS"""
        get_backend_ = getattr(self.dynlib, "flexiblas_current_backend", None)
        if get_backend_ is None:
            return None

        backend = ctypes.create_string_buffer(1024)
        get_backend_(backend, ctypes.sizeof(backend))
        return backend.value.decode("utf-8")

    def switch_backend(self, backend):
        """Switch the backend of FlexiBLAS

        Parameters
        ----------
        backend : str
            The name or the path to the shared library of the backend to switch to. If
            the backend is not already loaded, it will be loaded first.

        Raises
        ------
        RuntimeError
            If loading the backend or switching to it reports -1.
        """
        if backend not in self.loaded_backends:
            if backend in self.available_backends:
                load_func = getattr(self.dynlib, "flexiblas_load_backend", lambda _: -1)
            else:  # assume backend is a path to a shared library
                load_func = getattr(
                    self.dynlib, "flexiblas_load_backend_library", lambda _: -1
                )
            res = load_func(str(backend).encode("utf-8"))
            if res == -1:
                raise RuntimeError(
                    f"Failed to load backend {backend!r}. It must either be the name of"
                    " a backend available in the FlexiBLAS configuration "
                    f"{self.available_backends} or the path to a valid shared library."
                )

            # Trigger a new search of loaded shared libraries since loading a new
            # backend caused a dlopen.
            self.parent._load_libraries()

        switch_func = getattr(self.dynlib, "flexiblas_switch", lambda _: -1)
        idx = self.loaded_backends.index(backend)
        res = switch_func(idx)
        if res == -1:
            raise RuntimeError(f"Failed to switch to backend {backend!r}.")
424
+
425
+
426
class MKLController(LibController):
    """Controller class for MKL"""

    user_api = "blas"
    internal_api = "mkl"
    filename_prefixes = ("libmkl_rt", "mkl_rt", "libblas")
    check_symbols = (
        "MKL_Get_Max_Threads",
        "MKL_Set_Num_Threads",
        "MKL_Get_Version_String",
        "MKL_Set_Threading_Layer",
    )

    def set_additional_attributes(self):
        """Expose the threading layer in the info dict"""
        self.threading_layer = self._get_threading_layer()

    def get_num_threads(self):
        """Return the value of `MKL_Get_Max_Threads`, or None if missing"""
        get_func = getattr(self.dynlib, "MKL_Get_Max_Threads", lambda: None)
        return get_func()

    def set_num_threads(self, num_threads):
        """Call `MKL_Set_Num_Threads(num_threads)` if the symbol exists"""
        set_func = getattr(self.dynlib, "MKL_Set_Num_Threads", lambda num_threads: None)
        return set_func(num_threads)

    def get_version(self):
        """Return the MKL version, or None if the symbol is missing

        The token following "Version " in the string filled by
        `MKL_Get_Version_String` is returned when that pattern is found;
        otherwise the whole string is returned, stripped.
        """
        if not hasattr(self.dynlib, "MKL_Get_Version_String"):
            return None

        res = ctypes.create_string_buffer(200)
        self.dynlib.MKL_Get_Version_String(res, 200)

        version = res.value.decode("utf-8")
        group = re.search(r"Version ([^ ]+) ", version)
        if group is not None:
            version = group.groups()[0]
        return version.strip()

    def _get_threading_layer(self):
        """Return the threading layer of MKL"""
        # The function mkl_set_threading_layer returns the current threading
        # layer. Calling it with an invalid threading layer allows us to safely
        # get the threading layer
        set_threading_layer = getattr(
            self.dynlib, "MKL_Set_Threading_Layer", lambda layer: -1
        )
        layer_map = {
            0: "intel",
            1: "sequential",
            2: "pgi",
            3: "gnu",
            4: "tbb",
            -1: "not specified",
        }
        return layer_map[set_threading_layer(-1)]
480
+
481
+
482
class OpenMPController(LibController):
    """Controller class for OpenMP"""

    user_api = "openmp"
    internal_api = "openmp"
    filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp")
    check_symbols = (
        "omp_get_max_threads",
        "omp_get_num_threads",
    )

    def get_num_threads(self):
        """Return the value of `omp_get_max_threads`, or None if missing"""
        get_func = getattr(self.dynlib, "omp_get_max_threads", lambda: None)
        return get_func()

    def set_num_threads(self, num_threads):
        """Call `omp_set_num_threads(num_threads)` if the symbol exists"""
        set_func = getattr(self.dynlib, "omp_set_num_threads", lambda num_threads: None)
        return set_func(num_threads)

    def get_version(self):
        """Always return None"""
        # There is no way to get the version number programmatically in OpenMP.
        return None
504
+
505
+
506
# Controllers for the libraries that we'll look for in the loaded libraries.
# Third party libraries can register their own controllers.
_ALL_CONTROLLERS = [
    OpenBLASController,
    BLISController,
    MKLController,
    OpenMPController,
    FlexiBLASController,
]

# Helpers for the doc and test names.
# `dict.fromkeys` removes duplicates while keeping first-seen order, so these
# lists (and the docstrings formatted from them) are identical from one run to
# the next, unlike an ordering derived from a `set`.
_ALL_USER_APIS = list(dict.fromkeys(lib.user_api for lib in _ALL_CONTROLLERS))
_ALL_INTERNAL_APIS = [lib.internal_api for lib in _ALL_CONTROLLERS]
_ALL_PREFIXES = list(
    dict.fromkeys(
        prefix for lib in _ALL_CONTROLLERS for prefix in lib.filename_prefixes
    )
)
_ALL_BLAS_LIBRARIES = [
    lib.internal_api for lib in _ALL_CONTROLLERS if lib.user_api == "blas"
]
_ALL_OPENMP_LIBRARIES = OpenMPController.filename_prefixes
526
+
527
+
528
def register(controller):
    """Register a new controller

    Adds `controller` to `_ALL_CONTROLLERS` and records its `user_api`,
    `internal_api` and `filename_prefixes` in the matching helper lists.
    Entries already present are not added again, so registering the same
    controller twice (or a controller sharing an existing `user_api`) does not
    create duplicates.

    Parameters
    ----------
    controller : class
        A controller class exposing the `user_api`, `internal_api` and
        `filename_prefixes` class attributes.
    """
    if controller not in _ALL_CONTROLLERS:
        _ALL_CONTROLLERS.append(controller)
    if controller.user_api not in _ALL_USER_APIS:
        _ALL_USER_APIS.append(controller.user_api)
    if controller.internal_api not in _ALL_INTERNAL_APIS:
        _ALL_INTERNAL_APIS.append(controller.internal_api)
    for prefix in controller.filename_prefixes:
        if prefix not in _ALL_PREFIXES:
            _ALL_PREFIXES.append(prefix)
534
+
535
+
536
def _format_docstring(*args, **kwargs):
    """Return a decorator that applies `str.format` to an object's docstring

    The positional and keyword arguments are forwarded to `str.format`. An
    object without a docstring (`__doc__` is None) is returned unchanged.
    """

    def decorator(o):
        if o.__doc__ is not None:
            o.__doc__ = o.__doc__.format(*args, **kwargs)
        return o

    return decorator
543
+
544
+
545
@lru_cache(maxsize=10000)
def _realpath(filepath):
    """Small caching wrapper around os.path.realpath to limit system calls

    Results are memoized per `filepath` in a cache bounded to 10000 entries.
    """
    return os.path.realpath(filepath)
549
+
550
+
551
@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
    """Return the maximal number of threads for each detected library.

    Return a list with all the supported libraries that have been found. Each
    library is represented by a dict with the following information:

      - "user_api" : user API. Possible values are {USER_APIS}.
      - "internal_api": internal API. Possible values are {INTERNAL_APIS}.
      - "prefix" : filename prefix of the specific implementation.
      - "filepath": path to the loaded library.
      - "version": version of the library (if available).
      - "num_threads": the current thread limit.

    In addition, each library may contain internal_api specific entries.
    """
    return ThreadpoolController().info()
568
+
569
+
570
class _ThreadpoolLimiter:
    """The guts of ThreadpoolController.limit

    Refer to the docstring of ThreadpoolController.limit for more details.

    It will only act on the library controllers held by the provided `controller`.
    Using the default constructor sets the limits right away such that it can be used as
    a callable. Setting the limits can be delayed by using the `wrap` class method such
    that it can be used as a decorator.
    """

    def __init__(self, controller, *, limits=None, user_api=None):
        self._controller = controller
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        # Snapshot taken before applying the limits so that they can be
        # restored later.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.restore_original_limits()

    @classmethod
    def wrap(cls, controller, *, limits=None, user_api=None):
        """Return an instance of this class that can be used as a decorator"""
        return _ThreadpoolLimiterDecorator(
            controller=controller, limits=limits, user_api=user_api
        )

    def restore_original_limits(self):
        """Set the limits back to their original values"""
        # The snapshot and the library controllers are paired by position.
        for lib_controller, original_info in zip(
            self._controller.lib_controllers, self._original_info
        ):
            lib_controller.set_num_threads(original_info["num_threads"])

    # Alias of `restore_original_limits` for backward compatibility
    unregister = restore_original_limits

    def get_original_num_threads(self):
        """Original num_threads from before calling threadpool_limits

        Return a dict `{user_api: num_threads}`. When several libraries sharing
        a user API reported different values, the minimum is returned for that
        API and a warning is emitted. When no library matches a user API, its
        value is None.
        """
        num_threads = {}
        warning_apis = []

        for user_api in self._user_api:
            limits = {
                lib_info["num_threads"]
                for lib_info in self._original_info
                if lib_info["user_api"] == user_api
            }
            n_limits = len(limits)

            if n_limits == 1:
                limit = limits.pop()
            elif n_limits == 0:
                limit = None
            else:
                limit = min(limits)
                warning_apis.append(user_api)

            num_threads[user_api] = limit

        if warning_apis:
            warnings.warn(
                "Multiple value possible for following user apis: "
                + ", ".join(warning_apis)
                + ". Returning the minimum."
            )

        return num_threads

    def _check_params(self, limits, user_api):
        """Suitable values for the _limits, _user_api and _prefixes attributes

        Returns a `(limits, user_api, prefixes)` tuple where `limits` is either
        None or a dict `{key: num_threads}` keyed by user API or library
        prefix.

        Raises
        ------
        ValueError
            If `limits` is None or an int and `user_api` is neither None nor a
            known user API.
        TypeError
            If `limits` cannot be converted to a dict.
        """

        if isinstance(limits, str) and limits == "sequential_blas_under_openmp":
            (
                limits,
                user_api,
            ) = self._controller._get_params_for_sequential_blas_under_openmp().values()

        if limits is None or isinstance(limits, int):
            if user_api is None:
                user_api = _ALL_USER_APIS
            elif user_api in _ALL_USER_APIS:
                user_api = [user_api]
            else:
                raise ValueError(
                    f"user_api must be either in {_ALL_USER_APIS} or None. Got "
                    f"{user_api} instead."
                )

            if limits is not None:
                limits = {api: limits for api in user_api}
            prefixes = []
        else:
            if isinstance(limits, list):
                # This should be a list of dicts of library info, for
                # compatibility with the result from threadpool_info.
                limits = {
                    lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits
                }
            elif isinstance(limits, ThreadpoolController):
                # To set the limits from the library controllers of a
                # ThreadpoolController object.
                limits = {
                    lib_controller.prefix: lib_controller.num_threads
                    for lib_controller in limits.lib_controllers
                }

            if not isinstance(limits, dict):
                raise TypeError(
                    "limits must either be an int, a list, a dict, or "
                    f"'sequential_blas_under_openmp'. Got {type(limits)} instead"
                )

            # With a dictionary, can set both specific limit for given
            # libraries and global limit for user_api. Fetch each separately.
            prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES]
            user_api = [api for api in limits if api in _ALL_USER_APIS]

        return limits, user_api, prefixes

    def _set_threadpool_limits(self):
        """Change the maximal number of threads in selected thread pools.

        Acts on the library controllers of `self._controller` whose prefix or
        user API is a key of `self._limits`. Does nothing when `self._limits`
        is None.
        """
        if self._limits is None:
            return

        for lib_controller in self._controller.lib_controllers:
            # self._limits is a dict {key: num_threads} where key is either
            # a prefix or a user_api. If a library matches both, the limit
            # corresponding to the prefix is chosen.
            if lib_controller.prefix in self._limits:
                num_threads = self._limits[lib_controller.prefix]
            elif lib_controller.user_api in self._limits:
                num_threads = self._limits[lib_controller.user_api]
            else:
                continue

            if num_threads is not None:
                lib_controller.set_num_threads(num_threads)
721
+
722
+
723
class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator):
    """Same as _ThreadpoolLimiter but to be used as a decorator"""

    def __init__(self, controller, *, limits=None, user_api=None):
        # Only validate and store the parameters here: the limits are applied
        # in `__enter__`, not at construction time.
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        self._controller = controller

    def __enter__(self):
        # we need to set the limits here and not in the __init__ because we want the
        # limits to be set when calling the decorated function, not when creating the
        # decorator.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()
        return self
739
+
740
+
741
@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class threadpool_limits(_ThreadpoolLimiter):
    """Change the maximal number of threads that can be used in thread pools.

    This object can be used either as a callable (the construction of this object
    limits the number of threads), as a context manager in a `with` block to
    automatically restore the original state of the controlled libraries when exiting
    the block, or as a decorator through its `wrap` method.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limit`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread level
    isolation as these libraries do not offer thread-local APIs to configure the number
    of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
          and `user_api` parameters for the specific use case of sequential BLAS
          calls within an OpenMP parallel region. The `user_api` parameter is
          ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """

    def __init__(self, limits=None, user_api=None):
        # A fresh ThreadpoolController is built for every instance.
        super().__init__(ThreadpoolController(), limits=limits, user_api=user_api)

    @classmethod
    def wrap(cls, limits=None, user_api=None):
        """Return a limiter usable as a decorator, bound to a fresh controller"""
        return super().wrap(ThreadpoolController(), limits=limits, user_api=user_api)
799
+
800
+
801
class ThreadpoolController:
    """Collection of LibController objects for all loaded supported libraries

    Instantiating this class scans the shared libraries currently loaded in
    the process (using a platform specific strategy) and stores one library
    controller per supported library that was found.

    Attributes
    ----------
    lib_controllers : list of `LibController` objects
        The list of library controllers of all loaded supported libraries.
    """

    # Cache for libc under POSIX and a few system libraries under Windows.
    # We use a class level cache instead of an instance level cache because
    # it's very unlikely that a shared library will be unloaded and reloaded
    # during the lifetime of a program.
    _system_libraries = dict()

    def __init__(self):
        """Scan the loaded libraries and warn about incompatible OpenMP runtimes."""
        self.lib_controllers = []
        self._load_libraries()
        self._warn_if_incompatible_openmp()

    @classmethod
    def _from_controllers(cls, lib_controllers):
        """Build a controller around an existing list of library controllers.

        `__init__` is bypassed on purpose: no library scan is performed and the
        given list is stored as is.
        """
        new_controller = cls.__new__(cls)
        new_controller.lib_controllers = lib_controllers
        return new_controller

    def info(self):
        """Return lib_controllers info as a list of dicts"""
        return [lib_controller.info() for lib_controller in self.lib_controllers]

    def select(self, **kwargs):
        """Return a ThreadpoolController containing a subset of its current
        library controllers

        It will select all libraries matching at least one pair (key, value) from kwargs
        where key is an entry of the library info dict (like "user_api", "internal_api",
        "prefix", ...) and value is the value or a list of acceptable values for that
        entry.

        For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])`
        will select all library controllers whose internal_api is either "blis" or
        "openblas".
        """
        # Normalize every criterion to a list of accepted values. A separate
        # mapping is built instead of rewriting `kwargs` while iterating on it.
        accepted = {
            key: vals if isinstance(vals, list) else [vals]
            for key, vals in kwargs.items()
        }

        lib_controllers = [
            lib_controller
            for lib_controller in self.lib_controllers
            if any(
                getattr(lib_controller, key, None) in vals
                for key, vals in accepted.items()
            )
        ]

        return ThreadpoolController._from_controllers(lib_controllers)

    def _get_params_for_sequential_blas_under_openmp(self):
        """Return appropriate params to use for a sequential BLAS call in an OpenMP loop

        This function takes into account the unexpected behavior of OpenBLAS with the
        OpenMP threading layer.
        """
        # `select` matches controllers satisfying at least one of the two
        # criteria; in that case no limit at all is requested.
        if self.select(
            internal_api="openblas", threading_layer="openmp"
        ).lib_controllers:
            return {"limits": None, "user_api": None}
        return {"limits": 1, "user_api": "blas"}

    @_format_docstring(
        USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
        BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
        OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
    )
    def limit(self, *, limits=None, user_api=None):
        """Change the maximal number of threads that can be used in thread pools.

        This function returns an object that can be used either as a callable (the
        construction of this object limits the number of threads) or as a context
        manager, in a `with` block to automatically restore the original state of the
        controlled libraries when exiting the block.

        Set the maximal number of threads that can be used in thread pools used in
        the supported libraries to `limits`. This function works for libraries that
        are already loaded in the interpreter and can be changed dynamically.

        This effect is global and impacts the whole Python process. There is no thread
        level isolation as these libraries do not offer thread-local APIs to configure
        the number of threads to use in nested parallel calls.

        Parameters
        ----------
        limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
            The maximal number of threads that can be used in thread pools

            - If int, sets the maximum number of threads to `limits` for each
              library selected by `user_api`.

            - If it is a dictionary `{{key: max_threads}}`, this function sets a
              custom maximum number of threads for each `key` which can be either a
              `user_api` or a `prefix` for a specific library.

            - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
              and `user_api` parameters for the specific use case of sequential BLAS
              calls within an OpenMP parallel region. The `user_api` parameter is
              ignored.

            - If None, this function does not do anything.

        user_api : {USER_APIS} or None (default=None)
            APIs of libraries to limit. Used only if `limits` is an int.

            - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

            - If "openmp", it will only limit OpenMP supported libraries
              ({OPENMP_LIBS}). Note that it can affect the number of threads used
              by the BLAS libraries if they rely on OpenMP.

            - If None, this function will apply to all supported libraries.
        """
        return _ThreadpoolLimiter(self, limits=limits, user_api=user_api)

    @_format_docstring(
        USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
        BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
        OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
    )
    def wrap(self, *, limits=None, user_api=None):
        """Change the maximal number of threads that can be used in thread pools.

        This function returns an object that can be used as a decorator.

        Set the maximal number of threads that can be used in thread pools used in
        the supported libraries to `limits`. This function works for libraries that
        are already loaded in the interpreter and can be changed dynamically.

        Parameters
        ----------
        limits : int, dict or None (default=None)
            The maximal number of threads that can be used in thread pools

            - If int, sets the maximum number of threads to `limits` for each
              library selected by `user_api`.

            - If it is a dictionary `{{key: max_threads}}`, this function sets a
              custom maximum number of threads for each `key` which can be either a
              `user_api` or a `prefix` for a specific library.

            - If None, this function does not do anything.

        user_api : {USER_APIS} or None (default=None)
            APIs of libraries to limit. Used only if `limits` is an int.

            - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

            - If "openmp", it will only limit OpenMP supported libraries
              ({OPENMP_LIBS}). Note that it can affect the number of threads used
              by the BLAS libraries if they rely on OpenMP.

            - If None, this function will apply to all supported libraries.
        """
        return _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api)

    def __len__(self):
        """Return the number of library controllers held by this object."""
        return len(self.lib_controllers)

    def _load_libraries(self):
        """Loop through loaded shared libraries and store the supported ones"""
        # The way to enumerate the loaded libraries is platform specific.
        if sys.platform == "darwin":
            self._find_libraries_with_dyld()
        elif sys.platform == "win32":
            self._find_libraries_with_enum_process_module_ex()
        elif "pyodide" in sys.modules:
            self._find_libraries_pyodide()
        else:
            self._find_libraries_with_dl_iterate_phdr()

    def _find_libraries_with_dl_iterate_phdr(self):
        """Loop through loaded libraries and store controllers for supported ones

        This function is expected to work on POSIX system only.
        This code is adapted from code by Intel developer @anton-malakhov
        available at https://github.com/IntelPython/smp

        Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause
        license
        """
        libc = self._get_libc()
        if not hasattr(libc, "dl_iterate_phdr"):  # pragma: no cover
            warnings.warn(
                "Could not find dl_iterate_phdr in the C standard library.",
                RuntimeWarning,
            )
            return

        # Callback function for `dl_iterate_phdr` which is called for every
        # library loaded in the current process until it returns 1.
        def match_library_callback(info, size, data):
            # Get the path of the current library
            filepath = info.contents.dlpi_name
            if filepath:
                filepath = filepath.decode("utf-8")

                # Store the library controller if it is supported and selected
                self._make_controller_from_path(filepath)
            # Always return 0 so that the iteration goes through all libraries.
            return 0

        c_func_signature = ctypes.CFUNCTYPE(
            ctypes.c_int,  # Return type
            ctypes.POINTER(_dl_phdr_info),
            ctypes.c_size_t,
            ctypes.c_char_p,
        )
        c_match_library_callback = c_func_signature(match_library_callback)

        # The `data` argument is not used by the callback; an empty string is
        # passed as a placeholder.
        data = ctypes.c_char_p(b"")
        libc.dl_iterate_phdr(c_match_library_callback, data)

    def _find_libraries_with_dyld(self):
        """Loop through loaded libraries and store controllers for supported ones

        This function is expected to work on OSX system only
        """
        libc = self._get_libc()
        if not hasattr(libc, "_dyld_image_count"):  # pragma: no cover
            warnings.warn(
                "Could not find _dyld_image_count in the C standard library.",
                RuntimeWarning,
            )
            return

        n_dyld = libc._dyld_image_count()
        libc._dyld_get_image_name.restype = ctypes.c_char_p

        for i in range(n_dyld):
            filepath = ctypes.string_at(libc._dyld_get_image_name(i))
            filepath = filepath.decode("utf-8")

            # Store the library controller if it is supported and selected
            self._make_controller_from_path(filepath)

    def _find_libraries_with_enum_process_module_ex(self):
        """Loop through loaded libraries and store controllers for supported ones

        This function is expected to work on windows system only.
        This code is adapted from code by Philipp Hagemeister @phihag available
        at https://stackoverflow.com/questions/17474574

        Raises
        ------
        OSError
            If the current process cannot be opened or if one of the Windows
            API calls used to enumerate the modules fails.
        """
        from ctypes.wintypes import DWORD, HMODULE, MAX_PATH

        PROCESS_QUERY_INFORMATION = 0x0400
        PROCESS_VM_READ = 0x0010

        LIST_LIBRARIES_ALL = 0x03

        ps_api = self._get_windll("Psapi")
        kernel_32 = self._get_windll("kernel32")

        h_process = kernel_32.OpenProcess(
            PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid()
        )
        if not h_process:  # pragma: no cover
            raise OSError(f"Could not open PID {os.getpid()}")

        try:
            buf_count = 256
            needed = DWORD()
            # Grow the buffer until it becomes large enough to hold all the
            # module headers
            while True:
                buf = (HMODULE * buf_count)()
                buf_size = ctypes.sizeof(buf)
                if not ps_api.EnumProcessModulesEx(
                    h_process,
                    ctypes.byref(buf),
                    buf_size,
                    ctypes.byref(needed),
                    LIST_LIBRARIES_ALL,
                ):
                    raise OSError("EnumProcessModulesEx failed")
                if buf_size >= needed.value:
                    break
                # `buf_size // buf_count` is the size of one module handle.
                buf_count = needed.value // (buf_size // buf_count)

            count = needed.value // (buf_size // buf_count)
            h_modules = map(HMODULE, buf[:count])

            # Loop through all the module headers and get the library path
            # Allocate a buffer for the path 10 times the size of MAX_PATH to take
            # into account long path names.
            max_path = 10 * MAX_PATH
            path_buf = ctypes.create_unicode_buffer(max_path)
            for h_module in h_modules:
                # Get the path of the current module. The last argument is the
                # capacity of the buffer in characters and is passed by value:
                # passing a pointer here would make the callee see an arbitrary
                # size.
                if not ps_api.GetModuleFileNameExW(
                    h_process, h_module, ctypes.byref(path_buf), DWORD(max_path)
                ):
                    raise OSError("GetModuleFileNameEx failed")
                filepath = path_buf.value

                # A path filling the buffer (with or without the terminating
                # null character) is considered truncated.
                if len(filepath) >= max_path - 1:  # pragma: no cover
                    warnings.warn(
                        "Could not get the full path of a dynamic library (path too "
                        "long). This library will be ignored and threadpoolctl might "
                        "not be able to control or display information about all "
                        f"loaded libraries. Here's the truncated path: {filepath!r}",
                        RuntimeWarning,
                    )
                else:
                    # Store the library controller if it is supported and selected
                    self._make_controller_from_path(filepath)
        finally:
            kernel_32.CloseHandle(h_process)

    def _find_libraries_pyodide(self):
        """Pyodide specific implementation for finding loaded libraries.

        Adapted from suggestion in https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1946696449.

        One day, we may have a simpler solution. libc dl_iterate_phdr needs to
        be implemented in Emscripten and exposed in Pyodide, see
        https://github.com/emscripten-core/emscripten/issues/21354 for more
        details.
        """
        try:
            from pyodide_js._module import LDSO
        except ImportError:
            warnings.warn(
                "Unable to import LDSO from pyodide_js._module. This should never "
                "happen."
            )
            return

        for filepath in LDSO.loadedLibsByName.as_object_map():
            # Some libraries are duplicated by Pyodide and do not exist in the
            # filesystem, so we first check for the existence of the file. For
            # more details, see
            # https://github.com/joblib/threadpoolctl/pull/169#issuecomment-1947946728
            if os.path.exists(filepath):
                self._make_controller_from_path(filepath)

    def _make_controller_from_path(self, filepath):
        """Store a library controller if it is supported and selected

        Parameters
        ----------
        filepath : str
            Path of a shared library loaded in the current process.
        """
        # Required to resolve symlinks
        filepath = _realpath(filepath)
        # `lower` required to take account of OpenMP dll case on Windows
        # (vcomp, VCOMP, Vcomp, ...)
        filename = os.path.basename(filepath).lower()

        # Loop through supported libraries to find if this filename corresponds
        # to a supported one.
        for controller_class in _ALL_CONTROLLERS:
            # check if filename matches a supported prefix
            prefix = self._check_prefix(filename, controller_class.filename_prefixes)

            # filename does not match any of the prefixes of the candidate
            # library. move to next library.
            if prefix is None:
                continue

            # workaround for BLAS libraries packaged by conda-forge on windows, which
            # are all renamed "libblas.dll". We thus have to check to which BLAS
            # implementation it actually corresponds looking for implementation
            # specific symbols.
            if prefix == "libblas":
                if filename.endswith(".dll"):
                    libblas = ctypes.CDLL(filepath, _RTLD_NOLOAD)
                    if not any(
                        hasattr(libblas, func)
                        for func in controller_class.check_symbols
                    ):
                        continue
                else:
                    # We ignore libblas on other platforms than windows because there
                    # might be a libblas dso coming with openblas for instance that
                    # can't be used to instantiate a pertinent LibController (many
                    # symbols are missing) and would create confusion by making a
                    # duplicate entry in threadpool_info.
                    continue

            # We already have a controller for this library: skip it before
            # paying the cost of building a controller that would be thrown
            # away.
            if filepath in (lib.filepath for lib in self.lib_controllers):
                continue

            # filename matches a prefix. Now we check if the library has the symbols we
            # are looking for. If none of the symbols exists, it's very likely not the
            # expected library (e.g. a library having a common prefix with one of
            # our supported libraries). Otherwise, create and store the library
            # controller.
            lib_controller = controller_class(
                filepath=filepath, prefix=prefix, parent=self
            )

            if not hasattr(controller_class, "check_symbols") or any(
                hasattr(lib_controller.dynlib, func)
                for func in controller_class.check_symbols
            ):
                self.lib_controllers.append(lib_controller)

    def _check_prefix(self, library_basename, filename_prefixes):
        """Return the prefix library_basename starts with

        Return None if none matches.
        """
        for prefix in filename_prefixes:
            if library_basename.startswith(prefix):
                return prefix
        return None

    def _warn_if_incompatible_openmp(self):
        """Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded"""
        prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers]
        msg = textwrap.dedent(
            """
            Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
            the same time. Both libraries are known to be incompatible and this
            can cause random crashes or deadlocks on Linux when loaded in the
            same Python program.
            Using threadpoolctl may cause crashes or deadlocks. For more
            information and possible workarounds, please see
            https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md
            """
        )
        if "libomp" in prefixes and "libiomp" in prefixes:
            warnings.warn(msg, RuntimeWarning)

    @classmethod
    def _get_libc(cls):
        """Load the lib-C for unix systems.

        The loaded library is stored in the class level cache and reused by
        subsequent calls.
        """
        libc = cls._system_libraries.get("libc")
        if libc is None:
            # Remark: If libc is statically linked or if Python is linked against an
            # alternative implementation of libc like musl, find_library will return
            # None and CDLL will load the main program itself which should contain the
            # libc symbols. We still name it libc for convenience.
            # If the main program does not contain the libc symbols, it's ok because
            # we check their presence later anyway.
            libc = ctypes.CDLL(find_library("c"), mode=_RTLD_NOLOAD)
            cls._system_libraries["libc"] = libc
        return libc

    @classmethod
    def _get_windll(cls, dll_name):
        """Load a windows DLL

        `dll_name` is given without the ".dll" extension. The loaded DLL is
        stored in the class level cache and reused by subsequent calls.
        """
        dll = cls._system_libraries.get(dll_name)
        if dll is None:
            dll = ctypes.WinDLL(f"{dll_name}.dll")
            cls._system_libraries[dll_name] = dll
        return dll
1251
+
1252
+
1253
def _main():
    """Commandline interface to display thread-pool information and exit.

    The command line arguments are read from `sys.argv`. The modules given
    with `-i/--import` are imported first, then the statement given with
    `-c/--command` (if any) is executed, and finally the value returned by
    `threadpool_info()` is printed on the standard output as indented JSON.
    """
    import argparse
    import importlib
    import json
    import sys

    parser = argparse.ArgumentParser(
        usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
        description="Display thread-pool information and exit.",
    )
    parser.add_argument(
        "-i",
        "--import",
        dest="modules",
        nargs="*",
        default=(),
        help="Python modules to import before introspecting thread-pools.",
    )
    parser.add_argument(
        "-c",
        "--command",
        help="a Python statement to execute before introspecting thread-pools.",
    )

    options = parser.parse_args(sys.argv[1:])
    for module in options.modules:
        try:
            importlib.import_module(module, package=None)
        except ImportError:
            # A module that cannot be imported is reported on stderr and
            # skipped; the introspection still runs.
            print("WARNING: could not import", module, file=sys.stderr)

    if options.command:
        # NOTE: this runs arbitrary code taken from the command line. No
        # namespace is passed to `exec`, so the statement is executed with the
        # globals of this module and the local names of this function.
        exec(options.command)

    print(json.dumps(threadpool_info(), indent=2))
1289
+
1290
+
1291
# Run the command line interface only when this file is executed as a script
# (for instance with `python -m threadpoolctl`), not when it is imported.
if __name__ == "__main__":
    _main()