Spaces:
Sleeping
Sleeping
| from typing import Callable, Optional, Union, Tuple, Dict, Any | |
| from copy import copy | |
| import bisect | |
| import re | |
| import warp as wp | |
| _kernel_cache = dict() | |
| _struct_cache = dict() | |
| _func_cache = dict() | |
| _key_re = re.compile("[^0-9a-zA-Z_]+") | |
def _make_key(obj, suffix: str, use_qualified_name):
    """Build an identifier-safe cache key from the name of ``obj`` and ``suffix``.

    With ``use_qualified_name`` the name is ``<module>.<qualname>``; otherwise it is the bare ``__name__``.
    Every character that is not an ASCII letter, digit or underscore is then stripped.
    """
    if use_qualified_name:
        raw_name = f"{obj.__module__}.{obj.__qualname__}"
    else:
        raw_name = obj.__name__
    return _key_re.sub("", f"{raw_name}_{suffix}")
def get_func(func, suffix: str, use_qualified_name: bool = False):
    """Return the ``wp.Function`` wrapping ``func``, creating and caching it on first request.

    Args:
        func: Python function to wrap.
        suffix: string appended to the function name to form the cache key.
        use_qualified_name: build the key from ``module.qualname`` instead of ``__name__``.

    The cache key also serves as the Warp function key. The function is registered in the Warp module
    named after ``func.__module__``.
    """
    key = _make_key(func, suffix, use_qualified_name)
    warp_func = _func_cache.get(key)
    if warp_func is None:
        owner_module = wp.get_module(func.__module__)
        warp_func = wp.Function(func=func, key=key, namespace="", module=owner_module)
        _func_cache[key] = warp_func
    return warp_func
def dynamic_func(suffix: str, use_qualified_name=False):
    """Decorator factory that replaces the decorated Python function by its cached ``wp.Function``.

    ``suffix`` and ``use_qualified_name`` are forwarded to :func:`get_func`.
    """

    def decorate(func: Callable):
        return get_func(func, suffix=suffix, use_qualified_name=use_qualified_name)

    return decorate
def get_kernel(
    func,
    suffix: str,
    use_qualified_name: bool = False,
    kernel_options: Optional[Dict[str, Any]] = None,
):
    """Return the ``wp.Kernel`` built from ``func``, creating and caching it on first request.

    Each kernel is placed in a dedicated Warp module. That module's options are copied from the module
    of ``func`` and then overridden by ``kernel_options``.

    Args:
        func: Python function defining the kernel.
        suffix: string appended to the function name to form the cache key.
        use_qualified_name: build the key from ``module.qualname`` instead of ``__name__``.
        kernel_options: module options overriding those inherited from the module of ``func``.
            ``None`` means no overrides. Only used when the kernel is first created; ignored on cache hits.
    """
    key = _make_key(func, suffix, use_qualified_name)
    if key not in _kernel_cache:
        # Avoid creating too long file names -- can lead to issues on Windows
        # We could hash the key, but prefer to keep it human-readable
        # NOTE(review): keys sharing the same 128-character prefix presumably resolve to the same module,
        # whose options would then be replaced by the later call -- confirm this cannot happen in practice.
        module_name = f"{func.__module__}.dyn.{key}"[:128]
        module = wp.get_module(module_name)
        module.options = copy(wp.get_module(func.__module__).options)
        if kernel_options:
            module.options.update(kernel_options)
        _kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
    return _kernel_cache[key]
def dynamic_kernel(suffix: str, use_qualified_name=False, kernel_options: Optional[Dict[str, Any]] = None):
    """Decorator factory that replaces the decorated Python function by its cached ``wp.Kernel``.

    ``suffix``, ``use_qualified_name`` and ``kernel_options`` are forwarded to :func:`get_kernel`.
    ``kernel_options=None`` means no option overrides.
    """
    # Normalize once here instead of sharing a mutable default dict between all calls
    options = {} if kernel_options is None else kernel_options

    def decorate(func: Callable):
        return get_kernel(func, suffix=suffix, use_qualified_name=use_qualified_name, kernel_options=options)

    return decorate
def get_struct(struct: type, suffix: str, use_qualified_name: bool = False):
    """Return the ``wp.codegen.Struct`` built from ``struct``, creating and caching it on first request.

    Side effect: ``struct.__qualname__`` is overwritten with the cache key on every call, because the
    qualified name is used in code generation.

    NOTE(review): with ``use_qualified_name=True`` the key is derived from ``__qualname__`` itself, so a
    second call on the same class object yields a different key (and a new struct). Confirm that callers
    always pass a freshly defined class.
    """
    key = _make_key(struct, suffix, use_qualified_name)
    # used in codegen
    struct.__qualname__ = key

    warp_struct = _struct_cache.get(key)
    if warp_struct is None:
        warp_struct = wp.codegen.Struct(cls=struct, key=key, module=wp.get_module(struct.__module__))
        _struct_cache[key] = warp_struct
    return warp_struct
def dynamic_struct(suffix: str, use_qualified_name=False):
    """Decorator factory that replaces the decorated class by its cached ``wp.codegen.Struct``.

    ``suffix`` and ``use_qualified_name`` are forwarded to :func:`get_struct`.
    """

    def decorate(struct: type):
        return get_struct(struct, suffix=suffix, use_qualified_name=use_qualified_name)

    return decorate
def get_integrand_function(
    integrand: "warp.fem.operator.Integrand",
    suffix: str,
    func=None,
    annotations=None,
    code_transformers=None,
):
    """Return the ``wp.Function`` generated for ``integrand``, creating and caching it on first request.

    Args:
        integrand: integrand whose ``func`` determines the cache key and whose ``module`` hosts the function.
        suffix: string appended to the qualified name of ``integrand.func`` to form the cache key.
        func: Python function to wrap instead of ``integrand.func``. It does not affect the cache key.
        annotations: forwarded to ``wp.Function`` as ``overloaded_annotations``.
        code_transformers: forwarded to ``wp.Function``. ``None`` means no transformers.

    ``func``, ``annotations`` and ``code_transformers`` are only used when the function is first
    created; they are ignored on cache hits.
    """
    key = _make_key(integrand.func, suffix, use_qualified_name=True)
    if key not in _func_cache:
        _func_cache[key] = wp.Function(
            func=integrand.func if func is None else func,
            key=key,
            namespace="",
            module=integrand.module,
            overloaded_annotations=annotations,
            # Fresh list per call rather than a shared mutable default
            code_transformers=[] if code_transformers is None else code_transformers,
        )
    return _func_cache[key]
def get_integrand_kernel(
    integrand: "warp.fem.operator.Integrand",
    suffix: str,
    kernel_fn: Optional[Callable] = None,
    kernel_options: Optional[Dict[str, Any]] = None,
    code_transformers=None,
):
    """Return the ``wp.Kernel`` generated for ``integrand``, creating and caching it when possible.

    Args:
        integrand: integrand whose ``func`` determines the cache key. Its ``module`` provides the base options.
        suffix: string appended to the qualified name of ``integrand.func`` to form the cache key.
        kernel_fn: Python function defining the kernel. If ``None``, this is a lookup-only call.
        kernel_options: module options overriding those of ``integrand.module``. ``None`` means no overrides.
        code_transformers: forwarded to ``wp.Kernel``. ``None`` means no transformers.

    Returns:
        The cached kernel, or ``None`` if nothing is cached under the key and ``kernel_fn`` is ``None``.
    """
    key = _make_key(integrand.func, suffix, use_qualified_name=True)
    if key not in _kernel_cache:
        if kernel_fn is None:
            return None

        # Dedicated module per integrand, inheriting the options of the integrand's module
        module = wp.get_module(f"{integrand.module.name}.{integrand.name}")
        module.options = copy(integrand.module.options)
        if kernel_options:
            module.options.update(kernel_options)
        _kernel_cache[key] = wp.Kernel(
            func=kernel_fn,
            key=key,
            module=module,
            # Fresh list per call rather than a shared mutable default
            code_transformers=[] if code_transformers is None else code_transformers,
        )
    return _kernel_cache[key]
def cached_arg_value(func: Callable):
    """Decorator to be applied to member methods assembling Arg structs, so that the result gets
    automatically cached for the lifetime of the parent object

    The decorated method is called as ``method(obj, device)``. Results are stored per device ordinal in a
    dict attribute named ``_<method name>_cache`` on ``obj``, so each device's value is computed only once.
    The wrapper keeps the name, docstring and ``__wrapped__`` of the decorated method.
    """
    import functools

    cache_attr = f"_{func.__name__}_cache"

    @functools.wraps(func)
    def get_arg(obj, device):
        cache = getattr(obj, cache_attr, None)
        if cache is None:
            # First call on this object: attach the per-object cache lazily
            cache = {}
            setattr(obj, cache_attr, cache)

        device = wp.get_device(device)
        if device.ordinal not in cache:
            cache[device.ordinal] = func(obj, device)
        return cache[device.ordinal]

    return get_arg
| _cached_vec_types = {} | |
| _cached_mat_types = {} | |
def cached_vec_type(length, dtype):
    """Return the Warp vector type with ``length`` components of ``dtype``, reusing a previously created one."""
    key = (length, dtype)
    vec_type = _cached_vec_types.get(key)
    if vec_type is None:
        vec_type = _cached_vec_types[key] = wp.vec(length=length, dtype=dtype)
    return vec_type
def cached_mat_type(shape, dtype):
    """Return the Warp matrix type of the given ``shape`` and ``dtype``, reusing a previously created one."""
    key = tuple(shape) + (dtype,)
    mat_type = _cached_mat_types.get(key)
    if mat_type is None:
        mat_type = _cached_mat_types[key] = wp.mat(shape=shape, dtype=dtype)
    return mat_type
class Temporary:
    """Handle over a temporary array from a :class:`TemporaryStore`.

    The array will be automatically returned to the temporary pool for re-use upon destruction of this object, unless
    the temporary is explicitly detached from the pool using :meth:`detach`.

    The temporary may also be explicitly returned to the pool before destruction using :meth:`release`.
    """

    def __init__(self, array: wp.array, pool: Optional["TemporaryStore.Pool"] = None, shape=None, dtype=None):
        """
        Args:
            array: raw underlying array, allocated by ``pool`` or directly by the caller
            pool: pool that the raw array is returned to on release; ``None`` for an unpooled temporary
            shape: if given (or if ``dtype`` is given), :attr:`array` exposes a view of the raw array with this shape
            dtype: data type of the exposed view
        """
        self._raw_array = array
        self._array_view = array
        self._pool = pool

        if shape is not None or dtype is not None:
            self._view_as(shape=shape, dtype=dtype)

    def detach(self) -> wp.array:
        """Detaches the temporary so it is never returned to the pool, and returns its array view"""
        if self._pool is not None:
            self._pool.detach(self._raw_array)

        self._pool = None
        return self._array_view

    def release(self):
        """Returns the temporary array to the pool"""
        if self._pool is not None:
            self._pool.redeem(self._raw_array)

        self._pool = None

    # Exposed as a property: callers (e.g. borrow_temporary_like) read `temporary.array.shape`
    @property
    def array(self) -> wp.array:
        """View of the array with desired shape and data type."""
        return self._array_view

    def _view_as(self, shape, dtype) -> "Temporary":
        # Non-owning (owner=False) view over the start of the raw buffer, reinterpreted with the requested
        # shape and dtype. A pooled raw array may be larger than requested, hence "truncated".
        # The gradient array, if any, gets a matching view.
        def _view_reshaped_truncated(array):
            return wp.types.array(
                ptr=array.ptr,
                dtype=dtype,
                shape=shape,
                device=array.device,
                pinned=array.pinned,
                capacity=array.capacity,
                copy=False,
                owner=False,
                grad=None if array.grad is None else _view_reshaped_truncated(array.grad),
            )

        self._array_view = _view_reshaped_truncated(self._raw_array)
        return self

    def __del__(self):
        # `_pool` is missing if `__init__` failed before setting it; there is nothing to release then
        if getattr(self, "_pool", None) is not None:
            self.release()
class TemporaryStore:
    """
    Shared pool of temporary arrays that will be persisted and reused across invocations of ``warp.fem`` functions.

    A :class:`TemporaryStore` instance may either be passed explicitly to ``warp.fem`` functions that accept such an argument, for instance :func:`.integrate.integrate`,
    or can be set globally as the default store using :func:`set_default_temporary_store`.

    By default, there is no default temporary store, so that temporary allocations are not persisted.
    """

    _default_store: "TemporaryStore" = None

    class Pool:
        """Free list of flat (1D) arrays sharing one storage dtype, device and pinned flag.

        Free arrays are kept sorted by size, so the smallest one that is large enough for a request
        can be found by binary search.
        """

        def __init__(self, dtype, device, pinned: bool):
            self.dtype = dtype
            self.device = device
            self.pinned = pinned

            # Arrays currently free for borrowing (ascending by size) and their sizes, kept in lockstep
            self._pool = []
            self._pool_sizes = []
            # Every array allocated by this pool and not yet detached, borrowed or not, keyed by pointer
            self._allocs = {}

        def borrow(self, shape, dtype, requires_grad: bool):
            """Hands out a :class:`Temporary` viewing a pooled array with the requested ``shape`` and ``dtype``."""
            dims = (shape,) if isinstance(shape, int) else shape

            needed = 1
            for extent in dims:
                needed *= extent

            slot = bisect.bisect_left(a=self._pool_sizes, x=needed)

            if slot >= len(self._pool):
                # Nothing free is large enough. If the pool already holds free arrays, over-allocate
                # relative to the largest one so that a growing workload does not reallocate every time.
                if self._pool:
                    growth = 1.5
                    needed = max(int(self._pool_sizes[-1] * growth), needed)

                fresh = wp.empty(
                    shape=(needed,), dtype=self.dtype, pinned=self.pinned, device=self.device, requires_grad=requires_grad
                )
                self._allocs[fresh.ptr] = fresh
                return Temporary(pool=self, array=fresh, shape=dims, dtype=dtype)

            # Take the smallest sufficiently large free array out of the pool
            reused = self._pool.pop(slot)
            del self._pool_sizes[slot]

            if requires_grad and reused.grad is None:
                reused.requires_grad = True

            return Temporary(pool=self, array=reused, shape=dims, dtype=dtype)

        def redeem(self, array):
            """Puts ``array`` back among the free arrays, preserving the size ordering."""
            slot = bisect.bisect_left(a=self._pool_sizes, x=array.size)
            self._pool.insert(slot, array)
            self._pool_sizes.insert(slot, array.size)

        def detach(self, array):
            """Forgets ``array``, so that the pool no longer holds a reference to it."""
            del self._allocs[array.ptr]

    def __init__(self):
        self.clear()

    def clear(self):
        """Drops all pools. Temporaries already handed out keep references to their own array and pool."""
        self._temporaries = {}

    def borrow(self, shape, dtype, pinned: bool = False, device=None, requires_grad: bool = False) -> Temporary:
        """Borrows a temporary array of the given ``shape`` and ``dtype``.

        Pools are keyed by ``dtype._type_``, the type length, ``pinned`` and the device ordinal. Types with more
        than one component are stored as a vector of their scalar type and viewed back as ``dtype``.
        """
        warp_dtype = wp.types.type_to_warp(dtype)
        target_device = wp.get_device(device)
        length = wp.types.type_length(warp_dtype)
        pool_key = (warp_dtype._type_, length, pinned, target_device.ordinal)

        pool = self._temporaries.get(pool_key)
        if pool is None:
            if length > 1:
                storage_type = cached_vec_type(length=length, dtype=wp.types.type_scalar_type(warp_dtype))
            else:
                storage_type = warp_dtype
            pool = TemporaryStore.Pool(storage_type, target_device, pinned=pinned)
            self._temporaries[pool_key] = pool

        return pool.borrow(dtype=warp_dtype, shape=shape, requires_grad=requires_grad)
def set_default_temporary_store(temporary_store: Optional[TemporaryStore]):
    """Globally sets the default :class:`TemporaryStore` instance to use for temporary allocations in ``warp.fem`` functions.

    If the default temporary store is set to ``None``, temporary allocations are not persisted unless a :class:`TemporaryStore` is provided at a per-function granularity.
    """
    setattr(TemporaryStore, "_default_store", temporary_store)
def borrow_temporary(
    temporary_store: Optional[TemporaryStore],
    shape: Union[int, Tuple[int]],
    dtype: type,
    pinned: bool = False,
    requires_grad: bool = False,
    device=None,
) -> Temporary:
    """
    Borrows and returns a temporary array with specified attributes from a shared pool.

    If an array with sufficient capacity and matching desired attributes is already available in the pool, it will be returned.
    Otherwise, a new allocation will be performed.
    If no store is given and no default store is set, a fresh array is allocated that is not tied to any pool.

    Args:
        temporary_store: the shared pool to borrow the temporary from. If `temporary_store` is ``None``, the global default temporary store, if set, will be used.
        shape: desired dimensions for the temporary array
        dtype: desired data type for the temporary array
        pinned: whether a pinned allocation is desired
        requires_grad: whether the temporary array should have a gradient array
        device: device on which the memory should be allocated; if ``None``, the current device will be used.
    """
    store = TemporaryStore._default_store if temporary_store is None else temporary_store

    if store is not None:
        return store.borrow(shape=shape, dtype=dtype, device=device, pinned=pinned, requires_grad=requires_grad)

    unpooled = wp.empty(shape=shape, dtype=dtype, pinned=pinned, device=device, requires_grad=requires_grad)
    return Temporary(array=unpooled)
def borrow_temporary_like(
    array: Union[wp.array, Temporary],
    temporary_store: Optional[TemporaryStore],
) -> Temporary:
    """
    Borrows and returns a temporary array with the same attributes as another array or temporary.

    The shape, data type, pinned flag, device and ``requires_grad`` flag are copied from the reference.

    Args:
        array: Warp or temporary array to read the desired attributes from
        temporary_store: the shared pool to borrow the temporary from. If `temporary_store` is ``None``, the global default temporary store, if set, will be used.
    """
    reference = array.array if isinstance(array, Temporary) else array

    return borrow_temporary(
        temporary_store=temporary_store,
        shape=reference.shape,
        dtype=reference.dtype,
        pinned=reference.pinned,
        device=reference.device,
        requires_grad=reference.requires_grad,
    )