# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import ast
import ctypes
import gc
import hashlib
import inspect
import io
import itertools
import os
import platform
import sys
import types
from collections import deque
from copy import copy as shallowcopy
from types import ModuleType
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np

import warp
import warp.build
import warp.codegen
import warp.config


def create_value_func(type):
    """Return a ``value_func`` callback that always reports ``type``.

    The returned callable has the ``(args, kwds, templates)`` signature used
    for ``Function.value_func`` and ignores all three arguments.
    """

    def value_func(args, kwds, templates):
        return type

    return value_func


def get_function_args(func):
    """Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type.

    Raises:
        RuntimeError: If there are fewer annotations than positional arguments.
    """
    argspec = inspect.getfullargspec(func)

    # use source-level argument annotations
    if len(argspec.annotations) < len(argspec.args):
        raise RuntimeError(f"Incomplete argument annotations on function {func.__qualname__}")
    return argspec.annotations


# represents either a built-in or user-defined function
class Function:
    """A callable known to Warp: either a builtin (native) function or a user-defined one.

    A builtin is created with ``func=None`` and keeps its overloads in the
    ``overloads`` list. A user-defined function wraps a Python function and
    keeps ``user_templates`` (generic signatures) and ``user_overloads``
    (concrete signatures), both keyed by type signature.
    """

    def __init__(
        self,
        func,
        key,
        namespace,
        input_types=None,
        value_func=None,
        template_func=None,
        module=None,
        variadic=False,
        initializer_list_func=None,
        export=False,
        doc="",
        group="",
        hidden=False,
        skip_replay=False,
        missing_grad=False,
        generic=False,
        native_func=None,
        defaults=None,
        custom_replay_func=None,
        native_snippet=None,
        adj_native_snippet=None,
        skip_forward_codegen=False,
        skip_reverse_codegen=False,
        custom_reverse_num_input_args=-1,
        custom_reverse_mode=False,
        overloaded_annotations=None,
        code_transformers=None,
        skip_adding_overload=False,
        require_original_output_arg=False,
    ):
        """Create the function and, unless told otherwise, register it as an overload.

        Args:
            func: Python function decorated with ``@wp.func``, or ``None`` for builtins.
            key: Name under which the function is registered.
            namespace: Namespace string stored on the instance.
            input_types: Mapping of argument name to type. Only read for builtins;
                user functions take their types from the ``Adjoint`` annotations.
            value_func: Callable ``(args, kwds, templates)`` returning the value type.
            code_transformers: Forwarded to ``warp.codegen.Adjoint`` as ``transformers``.
                ``None`` (the default) means a fresh empty list.
            skip_adding_overload: When true, ``add_overload(self)`` is not called.

        The remaining arguments are stored as attributes of the same name or
        forwarded unchanged to ``warp.codegen.Adjoint``.
        """
        # a shared mutable default would leak state between Function instances
        if code_transformers is None:
            code_transformers = []

        self.func = func  # points to Python function decorated with @wp.func, may be None for builtins
        self.key = key
        self.namespace = namespace
        # a function that takes a list of args and a list of templates and returns the value type,
        # e.g.: load(array, index) returns the type of value being loaded
        self.value_func = value_func
        self.template_func = template_func
        self.input_types = {}
        self.export = export
        self.doc = doc
        self.group = group
        self.module = module
        self.variadic = variadic  # function can take arbitrary number of inputs, e.g.: printf()
        self.defaults = defaults
        # Function instance for a custom implementation of the replay pass
        self.custom_replay_func = custom_replay_func
        self.native_snippet = native_snippet
        self.adj_native_snippet = adj_native_snippet
        self.custom_grad_func = None
        self.require_original_output_arg = require_original_output_arg

        if initializer_list_func is None:
            self.initializer_list_func = lambda x, y: False
        else:
            # True if the arguments should be emitted as an initializer list in the c++ code
            self.initializer_list_func = initializer_list_func

        self.hidden = hidden  # function will not be listed in docs
        # whether or not operation will be performed during the forward replay in the backward pass
        self.skip_replay = skip_replay
        self.missing_grad = missing_grad  # whether or not builtin is missing a corresponding adjoint
        self.generic = generic

        # allow registering builtin functions with a different name in Python from the native code
        self.native_func = key if native_func is None else native_func

        if func:
            # user-defined function

            # generic and concrete overload lookups by type signature
            self.user_templates = {}
            self.user_overloads = {}

            # user defined (Python) function
            self.adj = warp.codegen.Adjoint(
                func,
                is_user_function=True,
                skip_forward_codegen=skip_forward_codegen,
                skip_reverse_codegen=skip_reverse_codegen,
                custom_reverse_num_input_args=custom_reverse_num_input_args,
                custom_reverse_mode=custom_reverse_mode,
                overload_annotations=overloaded_annotations,
                transformers=code_transformers,
            )

            # record input types; a "return" annotation fixes the value type instead
            for name, arg_type in self.adj.arg_types.items():
                if name == "return":
                    self.value_func = create_value_func(arg_type)
                else:
                    self.input_types[name] = arg_type

        else:
            # builtin function

            # embedded linked list of all overloads
            # the builtin_functions dictionary holds
            # the list head for a given key (func name)
            self.overloads = []

            # builtin (native) function, canonicalize argument types
            # (tolerate the `input_types=None` default instead of crashing on `.items()`)
            for name, arg_type in (input_types or {}).items():
                self.input_types[name] = warp.types.type_to_warp(arg_type)

            # cache mangled name
            self.mangled_name = self.mangle() if self.is_simple() else None

        if not skip_adding_overload:
            self.add_overload(self)

        # add to current module
        if module:
            module.register_function(self, skip_adding_overload)

    def __call__(self, *args, **kwargs):
        """Call the function from the CPython interpreter rather than from a kernel.

        Raises:
            RuntimeError: If no builtin or user overload accepts ``args``, if keyword
                arguments are used with user-defined overloads, or if the
                function has no Python implementation.
        """
        # handles calling a builtin (native) function
        # as if it was a Python function, i.e.: from
        # within the CPython interpreter rather than
        # from within a kernel (experimental).
        if self.is_builtin() and self.mangled_name:
            # For each of this function's existing overloads, we attempt to pack
            # the given arguments into the C types expected by the corresponding
            # parameters, and we rinse and repeat until we get a match.
            for overload in self.overloads:
                if overload.generic:
                    continue

                success, return_value = call_builtin(overload, *args)
                if success:
                    return return_value

            # overload resolution or call failed
            raise RuntimeError(
                f"Couldn't find a function '{self.key}' compatible with "
                f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
            )

        if hasattr(self, "user_overloads") and len(self.user_overloads):
            # user-defined function with overloads

            if len(kwargs):
                raise RuntimeError(
                    f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
                )

            # try and find a matching overload
            for overload in self.user_overloads.values():
                if len(overload.input_types) != len(args):
                    continue
                template_types = list(overload.input_types.values())
                arg_names = list(overload.input_types.keys())
                try:
                    # attempt to unify argument types with function template types
                    warp.types.infer_argument_types(args, template_types, arg_names)
                    # NOTE(review): an exception raised by the user's function body is
                    # also swallowed here and the next overload is tried; kept as-is
                    # because callers may rely on this best-effort behaviour.
                    return overload.func(*args)
                except Exception:
                    continue

            raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")

        # user-defined function with no overloads
        if self.func is None:
            raise RuntimeError(f"Error calling function '{self.key}', function is undefined")

        # this function has no overloads, call it like a plain Python function
        return self.func(*args, **kwargs)

    def is_builtin(self):
        """Return True when there is no Python implementation (``func is None``)."""
        return self.func is None

    def is_simple(self):
        """Return True if the function qualifies for a mangled export name.

        Variadic functions, functions taking ``warp.array`` instances, ``Any``,
        ``Callable`` or ``Tuple`` inputs, and functions whose return type string
        starts with ``"Tuple"`` or cannot be computed are not simple.
        """
        if self.variadic:
            return False

        # only export simple types that don't use arrays
        if any(
            isinstance(v, warp.array) or v == Any or v == Callable or v == Tuple for v in self.input_types.values()
        ):
            return False

        try:
            # todo: construct a default value for each of the functions args
            # so we can generate the return type for overloaded functions
            return_type = type_str(self.value_func(None, None, None))
        except Exception:
            return False

        return not return_type.startswith("Tuple")

    def mangle(self):
        """Build the mangled name of the C-exported function, e.g.: ``builtin_normalize_vec3``."""
        type_names = [t.__name__ for t in self.input_types.values()]
        return "_".join(["builtin_" + self.key, *type_names])

    def add_overload(self, f):
        """Register ``f`` as an overload of this function.

        Raises:
            RuntimeError: For user functions, if an overload with the same
                signature is already registered.
        """
        if self.is_builtin():
            # todo: note that it is an error to add two functions
            # with the exact same signature as this would cause compile
            # errors during compile time. We should check here if there
            # is a previously created function with the same signature
            self.overloads.append(f)

            # make sure variadic overloads appear last so non variadic
            # ones are matched first:
            self.overloads.sort(key=lambda f: f.variadic)

        else:
            # get function signature based on the input types
            sig = warp.types.get_signature(
                f.input_types.values(), func_name=f.key, arg_names=list(f.input_types.keys())
            )

            # check if generic
            if warp.types.is_generic_signature(sig):
                if sig in self.user_templates:
                    raise RuntimeError(
                        f"Duplicate generic function overload {self.key} with arguments {f.input_types.values()}"
                    )
                self.user_templates[sig] = f
            else:
                if sig in self.user_overloads:
                    raise RuntimeError(
                        f"Duplicate function overload {self.key} with arguments {f.input_types.values()}"
                    )
                self.user_overloads[sig] = f

    def get_overload(self, arg_types):
        """Return the user overload for ``arg_types``, instantiating a template if needed.

        A concrete overload with a matching signature is returned directly.
        Otherwise the first generic template whose types all match is copied,
        specialized for ``arg_types``, cached in ``user_overloads`` and returned.

        Returns:
            The matching ``Function``, or ``None`` if nothing matches.
        """
        assert not self.is_builtin()

        sig = warp.types.get_signature(arg_types, func_name=self.key)

        existing = self.user_overloads.get(sig)
        if existing is not None:
            return existing

        for template in self.user_templates.values():
            if len(template.input_types) != len(arg_types):
                continue

            # try to match the given types to the function template types
            template_types = template.input_types.values()
            if not all(warp.types.type_matches_template(a, t) for a, t in zip(arg_types, template_types)):
                continue

            # instantiate this function with the specified argument types
            overload_annotations = dict(zip(template.input_types.keys(), arg_types))

            ovl = shallowcopy(template)
            ovl.adj = warp.codegen.Adjoint(template.func, overload_annotations)
            ovl.input_types = overload_annotations
            # cleared so that ModuleBuilder.build_function() installs a fresh one for this instance
            ovl.value_func = None

            self.user_overloads[sig] = ovl
            return ovl

        # failed to find overload
        return None

    def __repr__(self):
        inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
        # the original returned an empty string and left `inputs_str` unused
        return f"<Function {self.key}({inputs_str})>"


def _flatten_nested_sequence(value):
    """Flatten a possibly nested sequence breadth-first into a list of leaves.

    Returns:
        ``(leaves, ndim)`` where ``ndim`` is the deepest nesting level at which
        a leaf was found, and at least 1.
    """
    leaves = []
    ndim = 1
    queue = deque([(0, value)])
    while queue:
        depth, elem = queue.popleft()

        children = None
        # a 1-character `str` iterates to itself, which would never terminate
        if not isinstance(elem, str):
            try:
                # If `elem` is a sequence, then it should be possible
                # to queue its elements for later processing.
                children = [(depth + 1, x) for x in elem]
            except TypeError:
                children = None

        if children is None:
            # Since `elem` doesn't seem to be a sequence, we must have
            # a leaf value that we need to add to our resulting array.
            leaves.append(elem)
            ndim = max(depth, ndim)
        else:
            queue.extend(children)

    return leaves, ndim


def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
    """Try to call the native builtin overload ``func`` with Python values.

    Each parameter is packed into the C type expected by the matching entry of
    ``func.input_types``, and the symbol named ``func.mangled_name`` is called
    on ``warp.context.runtime.core`` with the return value passed by address.

    Returns:
        ``(True, value)`` on success, or ``(False, None)`` when the parameters
        are not compatible with this overload.

    Raises:
        ValueError: If a non-Warp sequence parameter has elements of mixed types.
    """
    uses_non_warp_array_type = False

    # An overload can only match when the number of arguments agrees. Without
    # this check too few arguments raised IndexError (aborting overload
    # resolution) and surplus arguments were silently ignored.
    if len(params) != len(func.input_types):
        return (False, None)

    # Retrieve the built-in function from Warp's dll.
    c_func = getattr(warp.context.runtime.core, func.mangled_name)

    # Try gathering the parameters that the function expects and pack them
    # into their corresponding C types.
    c_params = []
    for param, arg_type in zip(params, func.input_types.values()):
        try:
            iter(param)
        except TypeError:
            is_array = False
        else:
            is_array = True

        if is_array:
            if not issubclass(arg_type, ctypes.Array):
                return (False, None)

            # The argument expects a built-in Warp type like a vector or a matrix.

            if isinstance(param, ctypes.Array):
                # The given parameter is also a built-in Warp type, so we only need
                # to make sure that it matches with the argument.
                if not warp.types.types_equal(type(param), arg_type):
                    return (False, None)

                if isinstance(param, arg_type):
                    c_param = param
                else:
                    # Cast the value to its argument type to make sure that it
                    # can be assigned to the field of the `Param` struct.
                    # This could error otherwise when, for example, the field type
                    # is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
                    # even though both types are semantically identical.
                    c_param = arg_type(param)
            else:
                # Flatten the parameter values into a flat 1-D array.
                arr, ndim = _flatten_nested_sequence(param)

                # Ensure that if the given parameter value is, say, a 2-D array,
                # then we try to resolve it against a matrix argument rather than
                # a vector.
                if ndim > len(arg_type._shape_):
                    return (False, None)

                if len(arr) != arg_type._length_:
                    return (False, None)

                # Retrieve the element type of the sequence while ensuring
                # that it's homogeneous.
                elem_type = type(arr[0])
                if any(type(x) is not elem_type for x in arr[1:]):
                    raise ValueError("All array elements must share the same type.")

                expected_elem_type = arg_type._wp_scalar_type_
                if not (
                    elem_type is expected_elem_type
                    or (elem_type is float and expected_elem_type is warp.types.float32)
                    or (elem_type is int and expected_elem_type is warp.types.int32)
                    or (
                        issubclass(elem_type, np.number)
                        # `.get` so that an unmapped NumPy dtype means "no match" rather than KeyError
                        and warp.types.np_dtype_to_warp_type.get(np.dtype(elem_type)) is expected_elem_type
                    )
                ):
                    # The parameter value has a type not matching the type defined
                    # for the corresponding argument.
                    return (False, None)

                if elem_type in warp.types.int_types:
                    # Pass the value through the expected integer type
                    # in order to evaluate any integer wrapping.
                    # For example `uint8(-1)` should result in the value `255`.
                    arr = tuple(elem_type._type_(x.value).value for x in arr)
                elif elem_type in warp.types.float_types:
                    # Extract the floating-point values.
                    arr = tuple(x.value for x in arr)

                c_param = arg_type()
                if warp.types.type_is_matrix(arg_type):
                    rows, cols = arg_type._shape_
                    for row in range(rows):
                        idx_start = row * cols
                        c_param[row] = arr[idx_start : idx_start + cols]
                else:
                    c_param[:] = arr

                uses_non_warp_array_type = True

            c_params.append(ctypes.byref(c_param))
        else:
            if issubclass(arg_type, ctypes.Array):
                return (False, None)

            if not (
                isinstance(param, arg_type)
                or (type(param) is float and arg_type is warp.types.float32)
                or (type(param) is int and arg_type is warp.types.int32)
                or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
            ):
                return (False, None)

            if type(param) in warp.types.scalar_types:
                param = param.value

            # try to pack as a scalar type
            if arg_type == warp.types.float16:
                c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
            else:
                c_params.append(arg_type._type_(param))

    # returns the corresponding ctype for a scalar or vector warp type
    value_type = func.value_func(None, None, None)
    if value_type == float:
        value_ctype = ctypes.c_float
    elif value_type == int:
        value_ctype = ctypes.c_int32
    elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
        value_ctype = value_type
    else:
        # scalar type
        value_ctype = value_type._type_

    # construct return value (passed by address)
    ret = value_ctype()
    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
    c_params.append(ret_addr)

    # Call the built-in function from Warp's dll.
    c_func(*c_params)

    # TODO: uncomment when we have a way to print warning messages only once.
    # if uses_non_warp_array_type:
    #     warp.utils.warn(
    #         "Support for built-in functions called with non-Warp array types, "
    #         "such as lists, tuples, NumPy arrays, and others, will be dropped "
    #         "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
    #         "`wp.quat`, or `wp.transform`.",
    #         DeprecationWarning,
    #         stacklevel=3
    #     )

    if issubclass(value_ctype, (ctypes.Array, ctypes.Structure)):
        # return vector types as ctypes
        return (True, ret)

    if value_type == warp.types.float16:
        return (True, warp.types.half_bits_to_float(ret.value))

    # return scalar types as int/float
    return (True, ret.value)


class KernelHooks:
    """Pair of ``forward`` and ``backward`` entry points for a kernel."""

    def __init__(self, forward, backward):
        self.forward = forward
        self.backward = backward


# caches source and compiled entry points for a kernel (will be populated after module loads)
class Kernel:
    """A Python function registered as a kernel, plus the overloads of a generic kernel."""

    def __init__(self, func, key=None, module=None, options=None, code_transformers=None):
        """Create the kernel and register it with its module.

        Args:
            func: The Python function implementing the kernel.
            key: Kernel name; when ``None`` the module generates a unique key
                from ``func.__name__``.
            module: Owning module; when ``None`` it is looked up from ``func.__module__``.
            options: Options dict stored on the kernel (``{}`` when ``None``).
            code_transformers: Forwarded to ``warp.codegen.Adjoint`` as ``transformers``.
                ``None`` (the default) means a fresh empty list.
        """
        # a shared mutable default would leak state between Kernel instances
        if code_transformers is None:
            code_transformers = []

        self.func = func

        self.module = get_module(func.__module__) if module is None else module

        if key is None:
            self.key = self.module.generate_unique_kernel_key(func.__name__)
        else:
            self.key = key

        self.options = {} if options is None else options

        self.adj = warp.codegen.Adjoint(func, transformers=code_transformers)

        # check if generic
        self.is_generic = any(warp.types.type_is_generic(t) for t in self.adj.arg_types.values())

        # unique signature (used to differentiate instances of generic kernels during codegen)
        self.sig = ""

        # known overloads for generic kernels, indexed by type signature
        self.overloads = {}

        # argument indices by name
        self.arg_indices = {a.label: i for i, a in enumerate(self.adj.args)}

        if self.module:
            self.module.register_kernel(self)

    def infer_argument_types(self, args):
        """Infer argument types for ``args`` against this kernel's parameter types.

        Raises:
            RuntimeError: If the number of arguments does not match the kernel.
        """
        template_types = list(self.adj.arg_types.values())

        if len(args) != len(template_types):
            raise RuntimeError(f"Invalid number of arguments for kernel {self.key}")

        arg_names = list(self.adj.arg_types.keys())

        return warp.types.infer_argument_types(args, template_types, arg_names)

    def add_overload(self, arg_types):
        """Instantiate this kernel for the given argument types.

        The new overload is stored in ``self.overloads`` under its type
        signature and the owning module is unloaded.

        Raises:
            RuntimeError: On a wrong number of types or a duplicate overload.
            TypeError: If a type is generic or does not match the kernel parameter.
        """
        if len(arg_types) != len(self.adj.arg_types):
            raise RuntimeError(f"Invalid number of arguments for kernel {self.key}")

        arg_names = list(self.adj.arg_types.keys())
        template_types = list(self.adj.arg_types.values())

        # make sure all argument types are concrete and match the kernel parameters
        for arg_name, arg_type, template_type in zip(arg_names, arg_types, template_types):
            if warp.types.type_matches_template(arg_type, template_type):
                continue
            if warp.types.type_is_generic(arg_type):
                raise TypeError(f"Kernel {self.key} argument '{arg_name}' cannot be generic, got {arg_type}")
            raise TypeError(
                f"Kernel {self.key} argument '{arg_name}' type mismatch: expected {template_type}, got {arg_type}"
            )

        # get a type signature from the given argument types
        sig = warp.types.get_signature(arg_types, func_name=self.key)
        if sig in self.overloads:
            raise RuntimeError(
                f"Duplicate overload for kernel {self.key}, an overload with the given arguments already exists"
            )

        overload_annotations = dict(zip(arg_names, arg_types))

        # instantiate this kernel with the given argument types
        ovl = shallowcopy(self)
        ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations)
        ovl.is_generic = False
        ovl.overloads = {}
        ovl.sig = sig

        self.overloads[sig] = ovl

        # presumably so the module gets rebuilt with the new overload; confirm against Module.unload()
        self.module.unload()

        return ovl

    def get_overload(self, arg_types):
        """Return the overload for ``arg_types``, creating it when not yet known."""
        sig = warp.types.get_signature(arg_types, func_name=self.key)

        ovl = self.overloads.get(sig)
        if ovl is not None:
            return ovl
        return self.add_overload(arg_types)

    def get_mangled_name(self):
        """Return ``key`` suffixed with the overload signature, or just ``key`` when there is none."""
        if self.sig:
            return f"{self.key}_{self.sig}"
        return self.key


# ----------------------


# decorator to register function, @func
def func(f):
    """Register ``f`` as a user function in its module and return the registered entry."""
    name = warp.codegen.make_full_qualified_name(f)

    m = get_module(f.__module__)
    # value_type not known yet, will be inferred during Adjoint.build()
    Function(func=f, key=name, namespace="", module=m, value_func=None)

    # return the top of the list of overloads for this key
    return m.functions[name]


def func_native(snippet, adj_snippet=None):
    """
    Decorator to register native code snippet, @func_native
    """

    def snippet_func(f):
        name = warp.codegen.make_full_qualified_name(f)

        m = get_module(f.__module__)
        # cuda snippets do not have a return value_type
        # (the instance registers itself with `m`, so no local reference is kept)
        Function(func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet)

        return m.functions[name]

    return snippet_func


def func_grad(forward_fn):
    """
    Decorator to register a custom gradient function for a given forward function.
    The function signature must correspond to one of the function overloads in the following way:
    the first part of the input arguments are the original input variables with the same types as their
    corresponding arguments in the original function, and the second part of the input arguments are the
    adjoint variables of the output variables (if available) of the original function with the same types as the
    output variables. The function must not return anything.
    """

    def wrapper(grad_fn):
        if any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values()):
            raise RuntimeError(
                f"Cannot define custom grad definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
            )

        reverse_args = dict(forward_fn.input_types)

        # create temporary Adjoint instance to analyze the function signature
        adj = warp.codegen.Adjoint(
            grad_fn, skip_forward_codegen=True, skip_reverse_codegen=False, transformers=forward_fn.adj.transformers
        )

        from warp.types import types_equal

        grad_args = adj.args
        grad_sig = warp.types.get_signature([arg.type for arg in grad_args], func_name=forward_fn.key)

        if any(warp.types.type_is_generic(x.type) for x in grad_args):
            raise RuntimeError(
                f"Cannot define custom grad definition for {forward_fn.key} since the provided grad function has generic input arguments."
            )

        def match_function(f):
            # check whether the function overload f matches the signature of the provided gradient function
            if not hasattr(f.adj, "return_var"):
                f.adj.build(None)
            expected_args = list(f.input_types.items())
            if f.adj.return_var is not None:
                expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
            if len(grad_args) != len(expected_args):
                return False
            return all(types_equal(a.type, exp_type) for a, (_, exp_type) in zip(grad_args, expected_args))

        def add_custom_grad(f: Function):
            # register custom gradient function
            f.custom_grad_func = Function(
                grad_fn,
                key=f.key,
                namespace=f.namespace,
                input_types=reverse_args,
                value_func=None,
                module=f.module,
                template_func=f.template_func,
                skip_forward_codegen=True,
                custom_reverse_mode=True,
                custom_reverse_num_input_args=len(f.input_types),
                skip_adding_overload=False,
                code_transformers=f.adj.transformers,
            )
            f.adj.skip_reverse_codegen = True

        if hasattr(forward_fn, "user_overloads") and len(forward_fn.user_overloads):
            # find matching overload for which this grad function is defined
            for sig, f in forward_fn.user_overloads.items():
                if not grad_sig.startswith(sig):
                    continue
                if match_function(f):
                    add_custom_grad(f)
                    return
            raise RuntimeError(
                f"No function overload found for gradient function {grad_fn.__qualname__} for function {forward_fn.key}"
            )
        else:
            # resolve return variables
            forward_fn.adj.build(None)

            # collected up front because the error message below lists them
            expected_args = list(forward_fn.input_types.items())
            if forward_fn.adj.return_var is not None:
                expected_args += [(f"adj_ret_{var.label}", var.type) for var in forward_fn.adj.return_var]

            # check if the signature matches this function
            if match_function(forward_fn):
                add_custom_grad(forward_fn)
            else:
                raise RuntimeError(
                    f"Gradient function {grad_fn.__qualname__} for function {forward_fn.key} has an incorrect signature. The arguments must match the "
                    "forward function arguments plus the adjoint variables corresponding to the return variables:"
                    f"\n{', '.join(map(lambda nt: f'{nt[0]}: {nt[1].__name__}', expected_args))}"
                )

    return wrapper


def func_replay(forward_fn):
    """
    Decorator to register a custom replay function for a given forward function.
    The replay function is the function version that is called in the forward phase of the backward pass (replay mode) and corresponds to the forward function by default.
    The provided function has to match the signature of one of the original forward function overloads.
    """

    def wrapper(replay_fn):
        if any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values()):
            raise RuntimeError(
                f"Cannot define custom replay definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
            )

        args = get_function_args(replay_fn)
        arg_types = list(args.values())
        if any(warp.types.type_is_generic(x) for x in arg_types):
            raise RuntimeError(
                f"Cannot define custom replay definition for {forward_fn.key} since the provided replay function has generic input arguments."
            )

        f = forward_fn.get_overload(arg_types)
        if f is None:
            inputs_str = ", ".join([f"{k}: {v.__name__}" for k, v in args.items()])
            raise RuntimeError(
                f"Could not find forward definition of function {forward_fn.key} that matches custom replay definition with arguments:\n{inputs_str}"
            )
        f.custom_replay_func = Function(
            replay_fn,
            key=f"replay_{f.key}",
            namespace=f.namespace,
            input_types=f.input_types,
            value_func=f.value_func,
            module=f.module,
            template_func=f.template_func,
            skip_reverse_codegen=True,
            skip_adding_overload=True,
            code_transformers=f.adj.transformers,
        )

    return wrapper


# decorator to register kernel, @kernel
def kernel(f=None, *, enable_backward=None):
    """Register ``f`` as a kernel; usable as ``@kernel`` or ``@kernel(enable_backward=...)``.

    Args:
        f: The function to register, or ``None`` when the decorator is called
            with keyword arguments.
        enable_backward: When not ``None``, stored in the kernel options under
            ``"enable_backward"``.
    """

    def wrapper(f, *args, **kwargs):
        options = {}

        if enable_backward is not None:
            options["enable_backward"] = enable_backward

        m = get_module(f.__module__)
        return Kernel(
            func=f,
            key=warp.codegen.make_full_qualified_name(f),
            module=m,
            options=options,
        )

    if f is None:
        # Arguments were passed to the decorator.
        return wrapper

    return wrapper(f)


# decorator to register struct, @struct
def struct(c):
    """Wrap class ``c`` in a ``warp.codegen.Struct`` bound to the module that defines it."""
    m = get_module(c.__module__)
    return warp.codegen.Struct(cls=c, key=warp.codegen.make_full_qualified_name(c), module=m)


def _is_allowed_overload_stmt(node):
    """Return True if ``node`` is ``pass``, an ellipsis (...), or a string expression."""
    if isinstance(node, ast.Pass):
        return True
    if not isinstance(node, ast.Expr):
        return False

    value = node.value
    if isinstance(value, ast.Constant):
        return isinstance(value.value, str) or value.value is Ellipsis
    if sys.version_info < (3, 8):
        # older parsers emit dedicated node classes for these literals; the names
        # are only touched here because they are deprecated on newer versions
        return isinstance(value, (ast.Str, ast.Ellipsis))
    return False


# overload a kernel with the given argument types
def overload(kernel, arg_types=None):
    """Add an overload to a generic kernel.

    Args:
        kernel: Either a ``Kernel`` (direct call, e.g. ``wp.overload(kernel, [args...])``)
            or a plain function used as a decorator (``@wp.overload``) whose name
            matches a kernel registered in the same module.
        arg_types: For the direct call form, a list of argument types or a dict
            mapping argument names to types.

    Returns:
        The new kernel overload for the direct call form, or the original
        generic kernel for the decorator form.

    Raises:
        RuntimeError: If the kernel is not generic or cannot be found, an argument
            name is invalid, the stub has a body or incomplete annotations, or
            ``kernel`` is neither a ``Kernel`` nor a function.
        TypeError: If ``arg_types`` is neither a list, a dict, nor ``None``.
    """
    if isinstance(kernel, Kernel):
        # handle cases where user calls us directly, e.g. wp.overload(kernel, [args...])

        if not kernel.is_generic:
            raise RuntimeError(f"Only generic kernels can be overloaded. Kernel {kernel.key} is not generic")

        if isinstance(arg_types, list):
            arg_list = arg_types
        elif isinstance(arg_types, dict):
            # substitute named args
            arg_list = [a.type for a in kernel.adj.args]
            for arg_name, arg_type in arg_types.items():
                idx = kernel.arg_indices.get(arg_name)
                if idx is None:
                    raise RuntimeError(f"Invalid argument name '{arg_name}' in overload of kernel {kernel.key}")
                arg_list[idx] = arg_type
        elif arg_types is None:
            arg_list = []
        else:
            raise TypeError("Kernel overload types must be given in a list or dict")

        # return new kernel overload
        return kernel.add_overload(arg_list)

    elif isinstance(kernel, types.FunctionType):
        # handle cases where user calls us as a function decorator (@wp.overload)

        # ensure this function name corresponds to a kernel
        fn = kernel
        module = get_module(fn.__module__)
        kernel = module.kernels.get(fn.__name__)
        if kernel is None:
            raise RuntimeError(f"Failed to find a kernel named '{fn.__name__}' in module {fn.__module__}")

        if not kernel.is_generic:
            raise RuntimeError(f"Only generic kernels can be overloaded. Kernel {kernel.key} is not generic")

        # ensure the function is defined without a body, only ellipsis (...), pass, or a string expression
        # TODO: should we allow defining a new body for kernel overloads?
        source = inspect.getsource(fn)
        tree = ast.parse(source)
        assert isinstance(tree, ast.Module)
        assert isinstance(tree.body[0], ast.FunctionDef)
        for node in tree.body[0].body:
            if not _is_allowed_overload_stmt(node):
                raise RuntimeError(
                    "Illegal statement in kernel overload definition. Only pass, ellipsis (...), comments, or docstrings are allowed"
                )

        # ensure all arguments are annotated
        argspec = inspect.getfullargspec(fn)
        if len(argspec.annotations) < len(argspec.args):
            raise RuntimeError(f"Incomplete argument annotations on kernel overload {fn.__name__}")

        # get type annotation list
        arg_list = [arg_type for arg_name, arg_type in argspec.annotations.items() if arg_name != "return"]

        # add new overload, but we must return the original kernel from @wp.overload decorator!
        kernel.add_overload(arg_list)
        return kernel

    else:
        raise RuntimeError("wp.overload() called with invalid argument!")


# registered builtin functions; each entry is the head Function for that key
builtin_functions = {}


def add_builtin(
    key,
    input_types=None,
    value_type=None,
    value_func=None,
    template_func=None,
    doc="",
    namespace="wp::",
    variadic=False,
    initializer_list_func=None,
    export=True,
    group="Other",
    hidden=False,
    skip_replay=False,
    missing_grad=False,
    native_func=None,
    defaults=None,
    require_original_output_arg=False,
):
    """Register a builtin function under ``key``.

    When any input type is generic and ``export`` is true, hidden concrete
    specializations are registered first (one per valid combination of hard
    coded Warp types) so the builtin can be used with those types outside
    kernels. The function itself is then registered in ``builtin_functions``,
    either as the head entry for ``key`` or as an overload of the existing head.

    Args:
        key: Name of the builtin.
        input_types: Mapping of argument name to type (``None`` means no arguments).
        value_type: Fixed return type, used when ``value_func`` is not given.
        value_func: Callable ``(args, kwds, templates)`` returning the return type.
        export: Whether the head entry is also set as an attribute of the ``warp`` module.

    The remaining arguments are forwarded unchanged to ``Function``.

    Raises:
        RuntimeError: If exporting would overwrite an existing ``warp`` attribute
            that is not an ``_overload_dummy`` stub.
    """
    if input_types is None:
        input_types = {}

    # wrap simple single-type functions with a value_func()
    if value_func is None:

        def value_func(args, kwds, templates):
            return value_type

    if initializer_list_func is None:

        def initializer_list_func(args, templates):
            return False

    if defaults is None:
        defaults = {}

    # Add specialized versions of this builtin if it's generic by matching arguments against
    # hard coded types. We do this so you can use hard coded warp types outside kernels:
    generic = any(warp.types.type_is_generic(x) for x in input_types.values())
    if generic and export:
        # get a list of existing generic vector types (includes matrices and stuff)
        # so we can match arguments against them:
        generic_vtypes = [x for x in warp.types.vector_types if hasattr(x, "_wp_generic_type_str_")]

        # deduplicate identical types:
        def typekey(t):
            return f"{t._wp_generic_type_str_}_{t._wp_type_params_}"

        typedict = {typekey(t): t for t in generic_vtypes}
        generic_vtypes = [typedict[k] for k in sorted(typedict.keys())]

        # collect the parent type names of all the generic arguments:
        def generic_names(l):
            for t in l:
                if hasattr(t, "_wp_generic_type_str_"):
                    yield t._wp_generic_type_str_
                elif warp.types.type_is_generic_scalar(t):
                    yield t.__name__

        genericset = set(generic_names(input_types.values()))

        # for each of those type names, get a list of all hard coded types derived
        # from them:
        def derived(name):
            if name == "Float":
                return warp.types.float_types
            elif name == "Scalar":
                return warp.types.scalar_types
            elif name == "Int":
                return warp.types.int_types
            return [x for x in generic_vtypes if x._wp_generic_type_str_ == name]

        gtypes = {k: derived(k) for k in genericset}

        # find the scalar data types supported by all the arguments by intersecting
        # sets:
        def scalar_type(t):
            if t in warp.types.scalar_types:
                return t
            return [p for p in t._wp_type_params_ if p in warp.types.scalar_types][0]

        scalartypes = [{scalar_type(x) for x in gtypes[k]} for k in gtypes.keys()]
        if scalartypes:
            scalartypes = scalartypes.pop().intersection(*scalartypes)

        scalartypes = list(scalartypes)
        scalartypes.sort(key=str)

        # generate function calls for each of these scalar types:
        for stype in scalartypes:
            # find concrete types for this scalar type (eg if the scalar type is float32
            # this dict will look something like this:
            # {"vec":[wp.vec2,wp.vec3,wp.vec4], "mat":[wp.mat22,wp.mat33,wp.mat44]})
            consistenttypes = {k: [x for x in v if scalar_type(x) == stype] for k, v in gtypes.items()}

            def typelist(param):
                if warp.types.type_is_generic_scalar(param):
                    return [stype]
                if hasattr(param, "_wp_generic_type_str_"):
                    l = consistenttypes[param._wp_generic_type_str_]
                    return [x for x in l if warp.types.types_equal(param, x, match_generic=True)]
                return [param]

            # gotta try generating function calls for all combinations of these argument types
            # now.
            typelists = [typelist(param) for param in input_types.values()]
            for argtypes in itertools.product(*typelists):
                # Some of these argument lists won't work, eg if the function is mul(), we won't be
                # able to do a matrix vector multiplication for a mat22 and a vec3, so we call value_func
                # on the generated argument list and skip generation if it fails.
                # This also gives us the return type, which we keep for later:
                try:
                    return_type = value_func(argtypes, {}, [])
                except Exception:
                    continue

                # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
                # in the list of hard coded types so it knows it's returning one of them:
                if hasattr(return_type, "_wp_generic_type_str_"):
                    return_type_match = [
                        x
                        for x in generic_vtypes
                        if x._wp_generic_type_str_ == return_type._wp_generic_type_str_
                        and x._wp_type_params_ == return_type._wp_type_params_
                    ]
                    if not return_type_match:
                        continue
                    return_type = return_type_match[0]

                # finally we can generate a function call for these concrete types:
                add_builtin(
                    key,
                    input_types=dict(zip(input_types.keys(), argtypes)),
                    value_type=return_type,
                    doc=doc,
                    namespace=namespace,
                    variadic=variadic,
                    initializer_list_func=initializer_list_func,
                    export=export,
                    group=group,
                    hidden=True,
                    skip_replay=skip_replay,
                    missing_grad=missing_grad,
                    require_original_output_arg=require_original_output_arg,
                )

    func = Function(
        func=None,
        key=key,
        namespace=namespace,
        input_types=input_types,
        value_func=value_func,
        template_func=template_func,
        variadic=variadic,
        initializer_list_func=initializer_list_func,
        export=export,
        doc=doc,
        group=group,
        hidden=hidden,
        skip_replay=skip_replay,
        missing_grad=missing_grad,
        generic=generic,
        native_func=native_func,
        defaults=defaults,
        require_original_output_arg=require_original_output_arg,
    )

    if key in builtin_functions:
        builtin_functions[key].add_overload(func)
    else:
        builtin_functions[key] = func

        # export means the function will be added to the `warp` module namespace
        # so that users can call it directly from the Python interpreter
        if export:
            if hasattr(warp, key):
                # check that we haven't already created something at this location
                # if it's just an overload stub for auto-complete then overwrite it
                if getattr(warp, key).__name__ != "_overload_dummy":
                    raise RuntimeError(
                        f"Trying to register builtin function '{key}' that would overwrite existing object."
                    )

            setattr(warp, key, func)


# global dictionary of modules
user_modules = {}


def get_module(name):
    """Return the Warp module registered under ``name``, creating it on first use.

    If the module exists but was created with a different loader object than
    the one currently found in ``sys.modules``, it and its dependents are
    unloaded and its kernels, functions, constants and structs are cleared.
    """
    # some modules might be manually imported using `importlib` without being
    # registered into `sys.modules`
    parent = sys.modules.get(name, None)
    parent_loader = None if parent is None else parent.__loader__

    if name not in user_modules:
        # Warp module didn't exist yet, so create a new one
        user_modules[name] = warp.context.Module(name, parent_loader)
        return user_modules[name]

    # check if the Warp module was created using a different loader object
    # if so, we assume the file has changed and we recreate the module to
    # clear out old kernels / functions
    if user_modules[name].loader is not parent_loader:
        old_module = user_modules[name]

        # Unload the old module and recursively unload all of its dependents.
        # This ensures that dependent modules will be re-hashed and reloaded on next launch.
        # The visited set tracks modules already visited to avoid circular references.
        def unload_recursive(module, visited):
            module.unload()
            visited.add(module)
            for d in module.dependents:
                if d not in visited:
                    unload_recursive(d, visited)

        unload_recursive(old_module, visited=set())

        # clear out old kernels, funcs, struct definitions
        old_module.kernels = {}
        old_module.functions = {}
        old_module.constants = []
        old_module.structs = {}
        old_module.loader = parent_loader

    return user_modules[name]


class ModuleBuilder:
    """Builds the functions and kernels of a module and generates its source code."""

    def __init__(self, module, options):
        """Build every user function overload and every non-generic kernel or kernel overload of ``module``."""
        self.functions = {}
        self.structs = {}
        self.options = options
        self.module = module

        # build all functions declared in the module
        for func in module.functions.values():
            for f in func.user_overloads.values():
                self.build_function(f)
                if f.custom_replay_func is not None:
                    self.build_function(f.custom_replay_func)

        # build all kernel entry points
        for kernel in module.kernels.values():
            if not kernel.is_generic:
                self.build_kernel(kernel)
            else:
                for k in kernel.overloads.values():
                    self.build_kernel(k)

    def build_struct_recursive(self, struct: warp.codegen.Struct):
        """Register ``struct`` and every struct reachable through its members or array dtypes."""
        structs = []

        stack = [struct]
        while stack:
            s = stack.pop()

            structs.append(s)

            for var in s.vars.values():
                if isinstance(var.type, warp.codegen.Struct):
                    stack.append(var.type)
                elif isinstance(var.type, warp.types.array) and isinstance(var.type.dtype, warp.codegen.Struct):
                    stack.append(var.type.dtype)

        # Build them in reverse to generate a correct dependency order.
        for s in reversed(structs):
            self.build_struct(s)

    def build_struct(self, struct):
        """Record ``struct`` for code generation."""
        # dict used as an ordered set
        self.structs[struct] = None

    def build_kernel(self, kernel):
        """Build the kernel's adjoint.

        Raises:
            TypeError: If the kernel returns a value whose ctype is not ``"void"``.
        """
        kernel.adj.build(self)

        if kernel.adj.return_var is not None:
            if kernel.adj.return_var.ctype() != "void":
                raise TypeError(f"Error, kernels can't have return values, got: {kernel.adj.return_var}")

    def build_function(self, func):
        """Build ``func`` once and fill in its ``value_func`` if it has none yet."""
        if func in self.functions:
            return

        func.adj.build(self)

        # complete the function return type after we have analyzed it (inferred from return statement in ast)
        if not func.value_func:

            def wrap(adj):
                def value_type(arg_types, kwds, templates):
                    if adj.return_var is None or len(adj.return_var) == 0:
                        return None
                    if len(adj.return_var) == 1:
                        return adj.return_var[0].type
                    else:
                        return [v.type for v in adj.return_var]

                return value_type

            func.value_func = wrap(func.adj)

        # use dict to preserve import order
        self.functions[func] = None

    def codegen(self, device):
        """Generate the module source for ``device``.

        The result is the CPU header when ``device == "cpu"`` (the CUDA header
        otherwise), followed by the structs, the functions and the kernels
        with their module entry points.
        """
        parts = []

        # code-gen structs
        for struct in self.structs.keys():
            parts.append(warp.codegen.codegen_struct(struct))

        # code-gen all imported functions
        for func in self.functions.keys():
            if func.native_snippet is None:
                parts.append(
                    warp.codegen.codegen_func(
                        func.adj, c_func_name=func.native_func, device=device, options=self.options
                    )
                )
            else:
                parts.append(
                    warp.codegen.codegen_snippet(
                        func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
                    )
                )

        for kernel in self.module.kernels.values():
            # each kernel gets an entry point in the module
            instances = kernel.overloads.values() if kernel.is_generic else (kernel,)
            for k in instances:
                parts.append(warp.codegen.codegen_kernel(k, device=device, options=self.options))
                parts.append(warp.codegen.codegen_module(k, device=device))

        # add headers
        if device == "cpu":
            header = warp.codegen.cpu_module_header
        else:
            header = warp.codegen.cuda_module_header

        return header + "".join(parts)


#
# -----------------------------------------------------
# stores all functions and kernels for a Python module
# creates a hash of the function to use for checking
# build cache
class Module:
    """Container for the kernels, functions and structs of one Python module.

    The module knows how to hash its content, build and load itself for a
    device, unload itself, and look up kernel entry points after loading.
    """

    def __init__(self, name, loader):
        self.name = name
        self.loader = loader

        self.kernels = {}
        self.functions = {}
        self.constants = []
        self.structs = {}

        self.cpu_module = None
        self.cuda_modules = {}  # module lookup by CUDA context

        self.cpu_build_failed = False
        self.cuda_build_failed = False

        self.options = {
            "max_unroll": 16,
            "enable_backward": warp.config.enable_backward,
            "fast_math": False,
            "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
            "mode": warp.config.mode,
        }

        # kernel hook lookup per device
        # hooks are stored with the module so they can be easily cleared when the module is reloaded.
        # -> See ``Module.get_kernel_hooks()``
        self.kernel_hooks = {}

        # Module dependencies are determined by scanning each function
        # and kernel for references to external functions and structs.
        #
        # When a referenced module is modified, all of its dependents need to be reloaded
        # on the next launch.  To detect this, a module's hash recursively includes
        # all of its references.
        # -> See ``Module.hash_module()``
        #
        # The dependency mechanism works for both static and dynamic (runtime) modifications.
        # When a module is reloaded at runtime, we recursively unload all of its
        # dependents, so that they will be re-hashed and reloaded on the next launch.
        # -> See ``get_module()``

        self.references = set()  # modules whose content we depend on
        self.dependents = set()  # modules that depend on our content

        # Since module hashing is recursive, we improve performance by caching the hash of the
        # module contents (kernel source, function source, and struct source).
        # After all kernels, functions, and structs are added to the module (usually at import time),
        # the content hash doesn't change.
        # -> See ``Module.hash_module_recursive()``
        self.content_hash = None

        # number of times module auto-generates kernel key for user
        # used to ensure unique kernel keys
        self.count = 0

    def register_struct(self, struct):
        """Store ``struct`` under its key and unload the module."""
        self.structs[struct.key] = struct

        # for a reload of module on next launch
        self.unload()

    def register_kernel(self, kernel):
        """Store ``kernel`` under its key, record its references and unload the module."""
        self.kernels[kernel.key] = kernel

        self.find_references(kernel.adj)

        # for a reload of module on next launch
        self.unload()

    def register_function(self, func, skip_adding_overload=False):
        """Register ``func``, either as a new entry, a replacement, or an overload.

        When a function with the same key already exists, the new function
        replaces it if both signatures are equal; otherwise it is added as an
        overload of the existing one unless ``skip_adding_overload`` is set.
        """
        if func.key not in self.functions:
            self.functions[func.key] = func
        else:
            # Check whether the new function's signature match any that has
            # already been registered. If so, then we simply override it, as
            # Python would do it, otherwise we register it as a new overload.
            func_existing = self.functions[func.key]
            sig = warp.types.get_signature(
                func.input_types.values(),
                func_name=func.key,
                arg_names=list(func.input_types.keys()),
            )
            sig_existing = warp.types.get_signature(
                func_existing.input_types.values(),
                func_name=func_existing.key,
                arg_names=list(func_existing.input_types.keys()),
            )
            if sig == sig_existing:
                self.functions[func.key] = func
            elif not skip_adding_overload:
                func_existing.add_overload(func)

        self.find_references(func.adj)

        # for a reload of module on next launch
        self.unload()

    def generate_unique_kernel_key(self, key):
        """Return ``key`` suffixed with a per-module counter, then bump the counter."""
        unique_key = f"{key}_{self.count}"
        self.count += 1
        return unique_key

    # collect all referenced functions / structs
    # given the AST of a function or kernel
    def find_references(self, adj):
        """Record modules referenced by ``adj`` through function calls or struct arguments."""

        def add_ref(ref):
            # link both directions; a module never references itself
            if ref is not self:
                self.references.add(ref)
                ref.dependents.add(self)

        # scan for function calls
        for node in ast.walk(adj.tree):
            if isinstance(node, ast.Call):
                try:
                    # try to resolve the function
                    func, _ = adj.resolve_static_expression(node.func, eval_types=False)

                    # if this is a user-defined function, add a module reference
                    if isinstance(func, warp.context.Function) and func.module is not None:
                        add_ref(func.module)

                except Exception:
                    # Lookups may fail for builtins, but that's ok.
                    # Lookups may also fail for functions in this module that haven't been imported yet,
                    # and that's ok too (not an external reference).
                    pass

        # scan for structs
        for arg in adj.args:
            if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
                add_ref(arg.type.module)

    def hash_module(self):
        """Return a SHA-256 digest of this module, its options and its references."""

        def get_annotations(obj: Any) -> Mapping[str, Any]:
            """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
            # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
            if isinstance(obj, type):
                return obj.__dict__.get("__annotations__", {})

            return getattr(obj, "__annotations__", {})

        def get_type_name(type_hint):
            # structs are represented by their underlying Python class
            if isinstance(type_hint, warp.codegen.Struct):
                return get_type_name(type_hint.cls)
            return type_hint

        def hash_recursive(module, visited):
            # Hash this module, including all referenced modules recursively.
            # The visited set tracks modules already visited to avoid circular references.

            # check if we need to update the content hash
            if not module.content_hash:
                # recompute content hash
                ch = hashlib.sha256()

                # struct source
                for struct in module.structs.values():
                    s = ",".join(
                        f"{name}: {get_type_name(type_hint)}"
                        for name, type_hint in get_annotations(struct.cls).items()
                    )
                    ch.update(bytes(s, "utf-8"))

                # functions source
                for func in module.functions.values():
                    s = func.adj.source
                    ch.update(bytes(s, "utf-8"))

                    if func.custom_grad_func:
                        s = func.custom_grad_func.adj.source
                        ch.update(bytes(s, "utf-8"))
                    if func.custom_replay_func:
                        # the replay source must be hashed too, otherwise edits to a
                        # custom replay function would not invalidate the build cache
                        s = func.custom_replay_func.adj.source
                        ch.update(bytes(s, "utf-8"))

                    # cache func arg types
                    for arg, arg_type in func.adj.arg_types.items():
                        s = f"{arg}: {get_type_name(arg_type)}"
                        ch.update(bytes(s, "utf-8"))

                # kernel source
                for kernel in module.kernels.values():
                    ch.update(bytes(kernel.adj.source, "utf-8"))
                    # cache kernel arg types
                    for arg, arg_type in kernel.adj.arg_types.items():
                        s = f"{arg}: {get_type_name(arg_type)}"
                        ch.update(bytes(s, "utf-8"))
                    # for generic kernels the Python source is always the same,
                    # but we hash the type signatures of all the overloads
                    if kernel.is_generic:
                        for sig in sorted(kernel.overloads.keys()):
                            ch.update(bytes(sig, "utf-8"))

                module.content_hash = ch.digest()

            h = hashlib.sha256()

            # content hash
            h.update(module.content_hash)

            # configuration parameters
            for k in sorted(module.options.keys()):
                s = f"{k}={module.options[k]}"
                h.update(bytes(s, "utf-8"))

            # ensure to trigger recompilation if flags affecting kernel compilation are changed
            if warp.config.verify_fp:
                h.update(bytes("verify_fp", "utf-8"))

            h.update(bytes(warp.config.mode, "utf-8"))

            # compile-time constants (global)
            if warp.types._constant_hash:
                h.update(warp.types._constant_hash.digest())

            # recurse on references
            visited.add(module)

            # sort by name so the digest does not depend on set iteration order
            sorted_deps = sorted(module.references, key=lambda m: m.name)
            for dep in sorted_deps:
                if dep not in visited:
                    dep_hash = hash_recursive(dep, visited)
                    h.update(dep_hash)

            return h.digest()

        return hash_recursive(self, visited=set())

    def load(self, device):
        """Build (or fetch from cache) and load this module for ``device``.

        Returns ``True`` when the module is loaded and ``False`` when an earlier
        build attempt for this device kind already failed.

        Raises:
            RuntimeError: if the required CPU or CUDA toolchain is unavailable.
            Exception: any build/load error is re-raised after the corresponding
                ``*_build_failed`` flag has been set.
        """
        from warp.utils import ScopedTimer

        device = get_device(device)

        if device.is_cpu:
            # check if already loaded
            if self.cpu_module:
                return True
            # avoid repeated build attempts
            if self.cpu_build_failed:
                return False
            if not warp.is_cpu_available():
                raise RuntimeError("Failed to build CPU module because no CPU buildchain was found")
        else:
            # check if already loaded
            if device.context in self.cuda_modules:
                return True
            # avoid repeated build attempts
            if self.cuda_build_failed:
                return False
            if not warp.is_cuda_available():
                raise RuntimeError("Failed to build CUDA module because CUDA is not available")

        with ScopedTimer(f"Module {self.name} load on device '{device}'", active=not warp.config.quiet):
            build_path = warp.build.kernel_bin_dir
            gen_path = warp.build.kernel_gen_dir

            # exist_ok avoids failing when another process creates the directory first
            os.makedirs(build_path, exist_ok=True)
            os.makedirs(gen_path, exist_ok=True)

            module_name = "wp_" + self.name
            module_path = os.path.join(build_path, module_name)
            module_hash = self.hash_module()

            builder = ModuleBuilder(self, self.options)

            if device.is_cpu:
                obj_path = os.path.join(build_path, module_name)
                obj_path = obj_path + ".o"
                cpu_hash_path = module_path + ".cpu.hash"

                # check cache
                if warp.config.cache_kernels and os.path.isfile(cpu_hash_path) and os.path.isfile(obj_path):
                    with open(cpu_hash_path, "rb") as f:
                        cache_hash = f.read()

                    if cache_hash == module_hash:
                        runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
                        self.cpu_module = module_name
                        return True

                # build
                try:
                    cpp_path = os.path.join(gen_path, module_name + ".cpp")

                    # write cpp sources
                    cpp_source = builder.codegen("cpu")

                    # the context manager closes the file even if the write fails
                    with open(cpp_path, "w") as cpp_file:
                        cpp_file.write(cpp_source)

                    # build object code
                    with ScopedTimer("Compile x86", active=warp.config.verbose):
                        warp.build.build_cpu(
                            obj_path,
                            cpp_path,
                            mode=self.options["mode"],
                            fast_math=self.options["fast_math"],
                            verify_fp=warp.config.verify_fp,
                        )

                    # update cpu hash
                    with open(cpu_hash_path, "wb") as f:
                        f.write(module_hash)

                    # load the object code
                    runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
                    self.cpu_module = module_name

                except Exception:
                    self.cpu_build_failed = True
                    raise

            elif device.is_cuda:
                # determine whether to use PTX or CUBIN
                if device.is_cubin_supported:
                    # get user preference specified either per module or globally
                    preferred_cuda_output = self.options.get("cuda_output") or warp.config.cuda_output
                    if preferred_cuda_output is not None:
                        use_ptx = preferred_cuda_output == "ptx"
                    else:
                        # determine automatically: older drivers may not be able to handle PTX generated using newer
                        # CUDA Toolkits, in which case we fall back on generating CUBIN modules
                        use_ptx = runtime.driver_version >= runtime.toolkit_version
                else:
                    # CUBIN not an option, must use PTX (e.g. CUDA Toolkit too old)
                    use_ptx = True

                if use_ptx:
                    output_arch = min(device.arch, warp.config.ptx_target_arch)
                    output_path = module_path + f".sm{output_arch}.ptx"
                else:
                    output_arch = device.arch
                    output_path = module_path + f".sm{output_arch}.cubin"

                cuda_hash_path = module_path + f".sm{output_arch}.hash"

                # check cache
                if warp.config.cache_kernels and os.path.isfile(cuda_hash_path) and os.path.isfile(output_path):
                    with open(cuda_hash_path, "rb") as f:
                        cache_hash = f.read()

                    if cache_hash == module_hash:
                        cuda_module = warp.build.load_cuda(output_path, device)
                        if cuda_module is not None:
                            self.cuda_modules[device.context] = cuda_module
                            return True

                # build
                try:
                    cu_path = os.path.join(gen_path, module_name + ".cu")

                    # write cuda sources
                    cu_source = builder.codegen("cuda")

                    # the context manager closes the file even if the write fails
                    with open(cu_path, "w") as cu_file:
                        cu_file.write(cu_source)

                    # generate PTX or CUBIN
                    with ScopedTimer("Compile CUDA", active=warp.config.verbose):
                        warp.build.build_cuda(
                            cu_path,
                            output_arch,
                            output_path,
                            config=self.options["mode"],
                            fast_math=self.options["fast_math"],
                            verify_fp=warp.config.verify_fp,
                        )

                    # update cuda hash
                    with open(cuda_hash_path, "wb") as f:
                        f.write(module_hash)

                    # load the module
                    cuda_module = warp.build.load_cuda(output_path, device)
                    if cuda_module is not None:
                        self.cuda_modules[device.context] = cuda_module
                    else:
                        raise Exception("Failed to load CUDA module")

                except Exception:
                    self.cuda_build_failed = True
                    raise

            return True

    def unload(self):
        """Unload any loaded CPU/CUDA code and clear cached hooks and content hash."""
        if self.cpu_module:
            runtime.llvm.unload_obj(self.cpu_module.encode("utf-8"))
            self.cpu_module = None

        # need to unload the CUDA module from all CUDA contexts where it is loaded
        # note: we ensure that this doesn't change the current CUDA context
        if self.cuda_modules:
            saved_context = runtime.core.cuda_context_get_current()
            for context, module in self.cuda_modules.items():
                runtime.core.cuda_unload_module(context, module)
            runtime.core.cuda_context_set_current(saved_context)
            self.cuda_modules = {}

        # clear kernel hooks
        self.kernel_hooks = {}

        # clear content hash
        self.content_hash = None

    # lookup and cache kernel entry points based on name, called after compilation / module load
    def get_kernel_hooks(self, kernel, device):
        """Return the cached forward/backward entry points of ``kernel`` on ``device``."""
        # get all hooks for this device
        device_hooks = self.kernel_hooks.get(device.context)
        if device_hooks is None:
            self.kernel_hooks[device.context] = device_hooks = {}

        # look up this kernel
        hooks = device_hooks.get(kernel)
        if hooks is not None:
            return hooks

        name = kernel.get_mangled_name()

        if device.is_cpu:
            func = ctypes.CFUNCTYPE(None)
            forward = func(
                runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))
            )
            backward = func(
                runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))
            )
        else:
            cu_module = self.cuda_modules[device.context]
            forward = runtime.core.cuda_get_kernel(
                device.context, cu_module, (name + "_cuda_kernel_forward").encode("utf-8")
            )
            backward = runtime.core.cuda_get_kernel(
                device.context, cu_module, (name + "_cuda_kernel_backward").encode("utf-8")
            )

        hooks = KernelHooks(forward, backward)
        device_hooks[kernel] = hooks
        return hooks


# -------------------------------------------
# execution context


# a
# simple allocator
# TODO: use a pooled allocator to avoid hitting the system allocator
class Allocator:
    """Allocates and frees memory for a single device through the runtime."""

    def __init__(self, device):
        self.device = device

    def alloc(self, size_in_bytes, pinned=False):
        """Allocate ``size_in_bytes`` on the device (optionally pinned on the CPU).

        Raises:
            RuntimeError: if the CUDA device is currently capturing a graph.
        """
        if self.device.is_cuda:
            if self.device.is_capturing:
                # report the device, not the allocator object, in the message
                raise RuntimeError(f"Cannot allocate memory on device {self.device} while graph capture is active")
            return runtime.core.alloc_device(self.device.context, size_in_bytes)
        elif self.device.is_cpu:
            if pinned:
                return runtime.core.alloc_pinned(size_in_bytes)
            else:
                return runtime.core.alloc_host(size_in_bytes)

    def free(self, ptr, size_in_bytes, pinned=False):
        """Free ``ptr``; ``size_in_bytes`` is accepted but not used here.

        Raises:
            RuntimeError: if the CUDA device is currently capturing a graph.
        """
        if self.device.is_cuda:
            if self.device.is_capturing:
                # report the device, not the allocator object, in the message
                raise RuntimeError(f"Cannot free memory on device {self.device} while graph capture is active")
            return runtime.core.free_device(self.device.context, ptr)
        elif self.device.is_cpu:
            if pinned:
                return runtime.core.free_pinned(ptr)
            else:
                return runtime.core.free_host(ptr)


class ContextGuard:
    """Context manager that makes a device's CUDA context current for a block.

    For CUDA devices the context is pushed on entry and popped on exit. For
    other devices the current CUDA context is saved and restored, but only
    when the CUDA driver is initialized.
    """

    def __init__(self, device):
        self.device = device

    def __enter__(self):
        if self.device.is_cuda:
            runtime.core.cuda_context_push_current(self.device.context)
        elif is_cuda_driver_initialized():
            self.saved_context = runtime.core.cuda_context_get_current()

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device.is_cuda:
            runtime.core.cuda_context_pop_current()
        elif is_cuda_driver_initialized():
            runtime.core.cuda_context_set_current(self.saved_context)


class Stream:
    """Wrapper around a CUDA stream handle bound to a CUDA device.

    A stream created by this class is owned and destroyed with the object;
    a stream passed in through the ``cuda_stream`` keyword is not owned.
    """

    def __init__(self, device=None, **kwargs):
        # set first so that __del__ is safe even if construction fails below
        self.owner = False

        # we can't use get_device() if called during init, but we can use an explicit Device arg
        if runtime is not None:
            device = runtime.get_device(device)
        elif not isinstance(device, Device):
            raise RuntimeError(
                "A device object is required when creating a stream before or during Warp initialization"
            )

        if not device.is_cuda:
            raise RuntimeError(f"Device {device} is not a CUDA device")

        # we pass cuda_stream through kwargs because cuda_stream=None is actually a valid value (CUDA default stream)
        if "cuda_stream" in kwargs:
            self.cuda_stream = kwargs["cuda_stream"]
        else:
            self.cuda_stream = device.runtime.core.cuda_stream_create(device.context)
            if not self.cuda_stream:
                raise RuntimeError(f"Failed to create stream on device {device}")
            self.owner = True

        self.device = device

    def __del__(self):
        if self.owner:
            runtime.core.cuda_stream_destroy(self.device.context, self.cuda_stream)

    def record_event(self, event=None):
        """Record ``event`` (or a newly created one) on this stream and return it.

        Raises:
            RuntimeError: if ``event`` belongs to a different device.
        """
        if event is None:
            event = Event(self.device)
        elif event.device != self.device:
            raise RuntimeError(
                f"Event from device {event.device} cannot be recorded on stream from device {self.device}"
            )

        runtime.core.cuda_event_record(self.device.context, event.cuda_event, self.cuda_stream)

        return event

    def wait_event(self, event):
        """Make this stream wait for ``event``."""
        runtime.core.cuda_stream_wait_event(self.device.context, self.cuda_stream, event.cuda_event)

    def wait_stream(self, other_stream, event=None):
        """Make this stream wait for ``other_stream``, using ``event`` or a new one."""
        if event is None:
            event = Event(other_stream.device)

        runtime.core.cuda_stream_wait_stream(
            self.device.context, self.cuda_stream, other_stream.cuda_stream, event.cuda_event
        )


class Event:
    """Wrapper around a CUDA event handle bound to a CUDA device."""

    # event creation flags
    class Flags:
        DEFAULT = 0x0
        BLOCKING_SYNC = 0x1
        DISABLE_TIMING = 0x2

    def __init__(self, device=None, cuda_event=None, enable_timing=False):
        # set first so that __del__ is safe even if construction fails below
        self.owner = False

        device = get_device(device)
        if not device.is_cuda:
            raise RuntimeError(f"Device {device} is not a CUDA device")

        self.device = device

        if cuda_event is not None:
            # wrap an existing event without taking ownership
            self.cuda_event = cuda_event
        else:
            flags = Event.Flags.DEFAULT
            if not enable_timing:
                flags |= Event.Flags.DISABLE_TIMING
            self.cuda_event = runtime.core.cuda_event_create(device.context, flags)
            if not self.cuda_event:
                raise RuntimeError(f"Failed to create event on device {device}")
            self.owner = True

    def __del__(self):
        if self.owner:
            runtime.core.cuda_event_destroy(self.device.context, self.cuda_event)


class Device:
    """Describes the CPU (``ordinal == -1``) or one CUDA device (``ordinal >= 0``)."""

    def __init__(self, runtime, alias, ordinal=-1, is_primary=False, context=None):
        """Initialize the device description.

        Raises:
            RuntimeError: if ``ordinal`` is neither -1 nor a valid CUDA device index.
        """
        self.runtime = runtime
        self.alias = alias
        self.ordinal = ordinal
        self.is_primary = is_primary

        # context can be None to avoid acquiring primary contexts until the device is used
        self._context = context

        # if the device context is not primary, it cannot be None
        if ordinal != -1 and not is_primary:
            assert context is not None

        # streams will be created when context is acquired
        self._stream = None
        self.null_stream = None

        # indicates whether CUDA graph capture is active for this device
        self.is_capturing = False

        self.allocator = Allocator(self)
        self.context_guard = ContextGuard(self)

        if self.ordinal == -1:
            # CPU device
            self.name = platform.processor() or "CPU"
            self.arch = 0
            self.is_uva = False
            self.is_cubin_supported = False
            self.is_mempool_supported = False

            # TODO: add more device-specific dispatch functions
            self.memset = runtime.core.memset_host
            self.memtile = runtime.core.memtile_host

        elif ordinal >= 0 and ordinal < runtime.core.cuda_device_get_count():
            # CUDA device
            self.name = runtime.core.cuda_device_get_name(ordinal).decode()
            self.arch = runtime.core.cuda_device_get_arch(ordinal)
            self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
            # check whether our NVRTC can generate CUBINs for this architecture
            self.is_cubin_supported = self.arch in runtime.nvrtc_supported_archs
            self.is_mempool_supported = runtime.core.cuda_device_is_memory_pool_supported(ordinal)

            # Warn the user of a possible misconfiguration of their system
            if not self.is_mempool_supported:
                warp.utils.warn(
                    f"Support for stream ordered memory allocators was not detected on device {ordinal}. "
                    "This can prevent the use of graphs and/or result in poor performance. "
                    "Is the UVM driver enabled?"
                )

            # initialize streams unless context acquisition is postponed
            if self._context is not None:
                self.init_streams()

            # TODO: add more device-specific dispatch functions
            # the lambdas read self.context lazily, at call time
            self.memset = lambda ptr, value, size: runtime.core.memset_device(self.context, ptr, value, size)
            self.memtile = lambda ptr, src, srcsize, reps: runtime.core.memtile_device(
                self.context, ptr, src, srcsize, reps
            )

        else:
            raise RuntimeError(f"Invalid device ordinal ({ordinal})")

    def init_streams(self):
        """Create the device's work stream and a wrapper for the CUDA default stream."""
        # create a stream for asynchronous work
        self.stream = Stream(self)

        # CUDA default stream for some synchronous operations
        self.null_stream = Stream(self, cuda_stream=None)

    @property
    def is_cpu(self):
        return self.ordinal < 0

    @property
    def is_cuda(self):
        return self.ordinal >= 0

    @property
    def context(self):
        """The device context; a primary context is retained on first access.

        Raises:
            RuntimeError: if the primary context cannot be acquired.
        """
        if self._context is not None:
            return self._context
        elif self.is_primary:
            # acquire primary context on demand
            self._context = self.runtime.core.cuda_device_primary_context_retain(self.ordinal)
            if self._context is None:
                raise RuntimeError(f"Failed to acquire primary context for device {self}")
            self.runtime.context_map[self._context] = self
            # initialize streams
            self.init_streams()
        return self._context

    @property
    def has_context(self):
        return self._context is not None

    @property
    def stream(self):
        if self.context:
            return self._stream
        else:
            raise RuntimeError(f"Device {self} is not a CUDA device")

    @stream.setter
    def stream(self, s):
        if self.is_cuda:
            if s.device != self:
                raise RuntimeError(f"Stream from device {s.device} cannot be used on device {self}")
            self._stream = s
            self.runtime.core.cuda_context_set_stream(self.context, s.cuda_stream)
        else:
            raise RuntimeError(f"Device {self} is not a CUDA device")

    @property
    def has_stream(self):
        return self._stream is not None

    def __str__(self):
        return self.alias

    def __repr__(self):
        return f"'{self.alias}'"

    def __eq__(self, other):
        """Compare with another device (by context) or with a device alias string."""
        if self is other:
            return True
        elif isinstance(other, Device):
            return self.context == other.context
        elif isinstance(other, str):
            if other == "cuda":
                # the bare "cuda" alias refers to the current CUDA device
                return self == self.runtime.get_current_cuda_device()
            else:
                return other == self.alias
        else:
            return False

    def make_current(self):
        """Make this device's context the current CUDA context, if it has one."""
        if self.context is not None:
            self.runtime.core.cuda_context_set_current(self.context)

    def can_access(self, other):
        """Return whether this device shares a context with, or has peer access to, ``other``."""
        other = self.runtime.get_device(other)
        if self.context == other.context:
            return True
        elif self.context is not None and other.context is not None:
            return bool(self.runtime.core.cuda_context_can_access_peer(self.context, other.context))
        else:
            return False


""" Meta-type for arguments that can be resolved to a concrete Device.
"""
Devicelike = Union[Device, str, None]


class Graph:
    """Holds a CUDA graph executable handle and destroys it on garbage collection."""

    def __init__(self, device: Device, exec: ctypes.c_void_p):
        self.device = device
        self.exec = exec

    def __del__(self):
        # use CUDA context guard to avoid side effects during garbage collection
        with self.device.context_guard:
            runtime.core.cuda_graph_destroy(self.device.context, self.exec)


class Runtime:
    def __init__(self):
        bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")

        if os.name == "nt":
            if sys.version_info[0] > 3 or sys.version_info[0] == 3 and sys.version_info[1] >= 8:
                # Python >= 3.8 this method to add dll search paths
                os.add_dll_directory(bin_path)

            else:
                # Python < 3.8 we add dll directory to path
                os.environ["PATH"] = bin_path + os.pathsep + os.environ["PATH"]

            warp_lib = os.path.join(bin_path, "warp.dll")
            llvm_lib = os.path.join(bin_path, "warp-clang.dll")

        elif sys.platform == "darwin":
            warp_lib = os.path.join(bin_path, "libwarp.dylib")
            llvm_lib = os.path.join(bin_path, "libwarp-clang.dylib")

        else:
            warp_lib = os.path.join(bin_path, "warp.so")
            llvm_lib = os.path.join(bin_path, "warp-clang.so")

        self.core = self.load_dll(warp_lib)

        if os.path.exists(llvm_lib):
            self.llvm = self.load_dll(llvm_lib)
            # setup c-types for warp-clang.dll
            self.llvm.lookup.restype = ctypes.c_uint64
        else:
            self.llvm = None

        # setup c-types for warp.dll
        self.core.alloc_host.argtypes = [ctypes.c_size_t]
        self.core.alloc_host.restype = ctypes.c_void_p
self.core.alloc_pinned.argtypes = [ctypes.c_size_t] self.core.alloc_pinned.restype = ctypes.c_void_p self.core.alloc_device.argtypes = [ctypes.c_void_p, ctypes.c_size_t] self.core.alloc_device.restype = ctypes.c_void_p self.core.float_to_half_bits.argtypes = [ctypes.c_float] self.core.float_to_half_bits.restype = ctypes.c_uint16 self.core.half_bits_to_float.argtypes = [ctypes.c_uint16] self.core.half_bits_to_float.restype = ctypes.c_float self.core.free_host.argtypes = [ctypes.c_void_p] self.core.free_host.restype = None self.core.free_pinned.argtypes = [ctypes.c_void_p] self.core.free_pinned.restype = None self.core.free_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.free_device.restype = None self.core.memset_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] self.core.memset_host.restype = None self.core.memset_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] self.core.memset_device.restype = None self.core.memtile_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t] self.core.memtile_host.restype = None self.core.memtile_device.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t, ] self.core.memtile_device.restype = None self.core.memcpy_h2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t] self.core.memcpy_h2h.restype = None self.core.memcpy_h2d.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t] self.core.memcpy_h2d.restype = None self.core.memcpy_d2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t] self.core.memcpy_d2h.restype = None self.core.memcpy_d2d.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t] self.core.memcpy_d2d.restype = None self.core.memcpy_peer.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t] self.core.memcpy_peer.restype = None self.core.array_copy_host.argtypes = [ 
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_copy_host.restype = ctypes.c_size_t self.core.array_copy_device.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_copy_device.restype = ctypes.c_size_t self.core.array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int] self.core.array_fill_host.restype = None self.core.array_fill_device.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int, ] self.core.array_fill_device.restype = None self.core.array_sum_double_host.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_sum_float_host.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_sum_double_device.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_sum_float_device.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_inner_double_host.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_inner_float_host.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_inner_double_device.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_inner_float_device.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.array_scan_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool] self.core.array_scan_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool] self.core.array_scan_int_device.argtypes 
= [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool] self.core.array_scan_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool] self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int] self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int] self.core.runlength_encode_int_host.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ] self.core.runlength_encode_int_device.argtypes = [ ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ] self.core.bvh_create_host.restype = ctypes.c_uint64 self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int] self.core.bvh_create_device.restype = ctypes.c_uint64 self.core.bvh_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int] self.core.bvh_destroy_host.argtypes = [ctypes.c_uint64] self.core.bvh_destroy_device.argtypes = [ctypes.c_uint64] self.core.bvh_refit_host.argtypes = [ctypes.c_uint64] self.core.bvh_refit_device.argtypes = [ctypes.c_uint64] self.core.mesh_create_host.restype = ctypes.c_uint64 self.core.mesh_create_host.argtypes = [ warp.types.array_t, warp.types.array_t, warp.types.array_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.mesh_create_device.restype = ctypes.c_uint64 self.core.mesh_create_device.argtypes = [ ctypes.c_void_p, warp.types.array_t, warp.types.array_t, warp.types.array_t, ctypes.c_int, ctypes.c_int, ctypes.c_int, ] self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64] self.core.mesh_destroy_device.argtypes = [ctypes.c_uint64] self.core.mesh_refit_host.argtypes = [ctypes.c_uint64] self.core.mesh_refit_device.argtypes = [ctypes.c_uint64] self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int] self.core.hash_grid_create_host.restype = ctypes.c_uint64 self.core.hash_grid_destroy_host.argtypes = 
[ctypes.c_uint64] self.core.hash_grid_update_host.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p, ctypes.c_int] self.core.hash_grid_reserve_host.argtypes = [ctypes.c_uint64, ctypes.c_int] self.core.hash_grid_create_device.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int] self.core.hash_grid_create_device.restype = ctypes.c_uint64 self.core.hash_grid_destroy_device.argtypes = [ctypes.c_uint64] self.core.hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p, ctypes.c_int] self.core.hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int] self.core.cutlass_gemm.argtypes = [ ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_float, ctypes.c_float, ctypes.c_bool, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ] self.core.cutlass_gemm.restypes = ctypes.c_bool self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64] self.core.volume_create_host.restype = ctypes.c_uint64 self.core.volume_get_buffer_info_host.argtypes = [ ctypes.c_uint64, ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_uint64), ] self.core.volume_get_tiles_host.argtypes = [ ctypes.c_uint64, ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_uint64), ] self.core.volume_destroy_host.argtypes = [ctypes.c_uint64] self.core.volume_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint64] self.core.volume_create_device.restype = ctypes.c_uint64 self.core.volume_f_from_tiles_device.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_bool, ] self.core.volume_f_from_tiles_device.restype = ctypes.c_uint64 self.core.volume_v_from_tiles_device.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_float, 
ctypes.c_float, ctypes.c_bool, ] self.core.volume_v_from_tiles_device.restype = ctypes.c_uint64 self.core.volume_i_from_tiles_device.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_float, ctypes.c_int, ctypes.c_float, ctypes.c_float, ctypes.c_float, ctypes.c_bool, ] self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64 self.core.volume_get_buffer_info_device.argtypes = [ ctypes.c_uint64, ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_uint64), ] self.core.volume_get_tiles_device.argtypes = [ ctypes.c_uint64, ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_uint64), ] self.core.volume_destroy_device.argtypes = [ctypes.c_uint64] self.core.volume_get_voxel_size.argtypes = [ ctypes.c_uint64, ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float), ctypes.POINTER(ctypes.c_float), ] bsr_matrix_from_triplets_argtypes = [ ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ] self.core.bsr_matrix_from_triplets_float_host.argtypes = bsr_matrix_from_triplets_argtypes self.core.bsr_matrix_from_triplets_double_host.argtypes = bsr_matrix_from_triplets_argtypes self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes self.core.bsr_matrix_from_triplets_float_host.restype = ctypes.c_int self.core.bsr_matrix_from_triplets_double_host.restype = ctypes.c_int self.core.bsr_matrix_from_triplets_float_device.restype = ctypes.c_int self.core.bsr_matrix_from_triplets_double_device.restype = ctypes.c_int bsr_transpose_argtypes = [ ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ctypes.c_uint64, ] self.core.bsr_transpose_float_host.argtypes = bsr_transpose_argtypes 
self.core.bsr_transpose_double_host.argtypes = bsr_transpose_argtypes self.core.bsr_transpose_float_device.argtypes = bsr_transpose_argtypes self.core.bsr_transpose_double_device.argtypes = bsr_transpose_argtypes self.core.is_cuda_enabled.argtypes = None self.core.is_cuda_enabled.restype = ctypes.c_int self.core.is_cuda_compatibility_enabled.argtypes = None self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int self.core.is_cutlass_enabled.argtypes = None self.core.is_cutlass_enabled.restype = ctypes.c_int self.core.cuda_driver_version.argtypes = None self.core.cuda_driver_version.restype = ctypes.c_int self.core.cuda_toolkit_version.argtypes = None self.core.cuda_toolkit_version.restype = ctypes.c_int self.core.cuda_driver_is_initialized.argtypes = None self.core.cuda_driver_is_initialized.restype = ctypes.c_bool self.core.nvrtc_supported_arch_count.argtypes = None self.core.nvrtc_supported_arch_count.restype = ctypes.c_int self.core.nvrtc_supported_archs.argtypes = [ctypes.POINTER(ctypes.c_int)] self.core.nvrtc_supported_archs.restype = None self.core.cuda_device_get_count.argtypes = None self.core.cuda_device_get_count.restype = ctypes.c_int self.core.cuda_device_primary_context_retain.argtypes = [ctypes.c_int] self.core.cuda_device_primary_context_retain.restype = ctypes.c_void_p self.core.cuda_device_get_name.argtypes = [ctypes.c_int] self.core.cuda_device_get_name.restype = ctypes.c_char_p self.core.cuda_device_get_arch.argtypes = [ctypes.c_int] self.core.cuda_device_get_arch.restype = ctypes.c_int self.core.cuda_device_is_uva.argtypes = [ctypes.c_int] self.core.cuda_device_is_uva.restype = ctypes.c_int self.core.cuda_context_get_current.argtypes = None self.core.cuda_context_get_current.restype = ctypes.c_void_p self.core.cuda_context_set_current.argtypes = [ctypes.c_void_p] self.core.cuda_context_set_current.restype = None self.core.cuda_context_push_current.argtypes = [ctypes.c_void_p] self.core.cuda_context_push_current.restype = None 
self.core.cuda_context_pop_current.argtypes = None self.core.cuda_context_pop_current.restype = None self.core.cuda_context_create.argtypes = [ctypes.c_int] self.core.cuda_context_create.restype = ctypes.c_void_p self.core.cuda_context_destroy.argtypes = [ctypes.c_void_p] self.core.cuda_context_destroy.restype = None self.core.cuda_context_synchronize.argtypes = [ctypes.c_void_p] self.core.cuda_context_synchronize.restype = None self.core.cuda_context_check.argtypes = [ctypes.c_void_p] self.core.cuda_context_check.restype = ctypes.c_uint64 self.core.cuda_context_get_device_ordinal.argtypes = [ctypes.c_void_p] self.core.cuda_context_get_device_ordinal.restype = ctypes.c_int self.core.cuda_context_is_primary.argtypes = [ctypes.c_void_p] self.core.cuda_context_is_primary.restype = ctypes.c_int self.core.cuda_context_get_stream.argtypes = [ctypes.c_void_p] self.core.cuda_context_get_stream.restype = ctypes.c_void_p self.core.cuda_context_set_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_context_set_stream.restype = None self.core.cuda_context_can_access_peer.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_context_can_access_peer.restype = ctypes.c_int self.core.cuda_stream_create.argtypes = [ctypes.c_void_p] self.core.cuda_stream_create.restype = ctypes.c_void_p self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_stream_destroy.restype = None self.core.cuda_stream_synchronize.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_stream_synchronize.restype = None self.core.cuda_stream_wait_event.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_stream_wait_event.restype = None self.core.cuda_stream_wait_stream.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] self.core.cuda_stream_wait_stream.restype = None self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint] self.core.cuda_event_create.restype = ctypes.c_void_p 
self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_event_destroy.restype = None self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_event_record.restype = None self.core.cuda_graph_begin_capture.argtypes = [ctypes.c_void_p] self.core.cuda_graph_begin_capture.restype = None self.core.cuda_graph_end_capture.argtypes = [ctypes.c_void_p] self.core.cuda_graph_end_capture.restype = ctypes.c_void_p self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_graph_launch.restype = None self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_graph_destroy.restype = None self.core.cuda_compile_program.argtypes = [ ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_bool, ctypes.c_bool, ctypes.c_bool, ctypes.c_bool, ctypes.c_char_p, ] self.core.cuda_compile_program.restype = ctypes.c_size_t self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p] self.core.cuda_load_module.restype = ctypes.c_void_p self.core.cuda_unload_module.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_unload_module.restype = None self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p] self.core.cuda_get_kernel.restype = ctypes.c_void_p self.core.cuda_launch_kernel.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.POINTER(ctypes.c_void_p), ] self.core.cuda_launch_kernel.restype = ctypes.c_size_t self.core.cuda_graphics_map.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_graphics_map.restype = None self.core.cuda_graphics_unmap.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_graphics_unmap.restype = None self.core.cuda_graphics_device_ptr_and_size.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint64), ctypes.POINTER(ctypes.c_size_t), ] self.core.cuda_graphics_device_ptr_and_size.restype = None 
self.core.cuda_graphics_register_gl_buffer.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint] self.core.cuda_graphics_register_gl_buffer.restype = ctypes.c_void_p self.core.cuda_graphics_unregister_resource.argtypes = [ctypes.c_void_p, ctypes.c_void_p] self.core.cuda_graphics_unregister_resource.restype = None self.core.init.restype = ctypes.c_int error = self.core.init() if error != 0: raise Exception("Warp initialization failed") self.device_map = {} # device lookup by alias self.context_map = {} # device lookup by context # register CPU device cpu_name = platform.processor() if not cpu_name: cpu_name = "CPU" self.cpu_device = Device(self, "cpu") self.device_map["cpu"] = self.cpu_device self.context_map[None] = self.cpu_device cuda_device_count = self.core.cuda_device_get_count() if cuda_device_count > 0: # get CUDA Toolkit and driver versions self.toolkit_version = self.core.cuda_toolkit_version() self.driver_version = self.core.cuda_driver_version() # get all architectures supported by NVRTC num_archs = self.core.nvrtc_supported_arch_count() if num_archs > 0: archs = (ctypes.c_int * num_archs)() self.core.nvrtc_supported_archs(archs) self.nvrtc_supported_archs = list(archs) else: self.nvrtc_supported_archs = [] # register CUDA devices self.cuda_devices = [] self.cuda_primary_devices = [] for i in range(cuda_device_count): alias = f"cuda:{i}" device = Device(self, alias, ordinal=i, is_primary=True) self.cuda_devices.append(device) self.cuda_primary_devices.append(device) self.device_map[alias] = device # set default device if cuda_device_count > 0: if self.core.cuda_context_get_current() is not None: self.set_default_device("cuda") else: self.set_default_device("cuda:0") else: # CUDA not available self.set_default_device("cpu") # initialize kernel cache warp.build.init_kernel_cache(warp.config.kernel_cache_dir) # print device and version information if not warp.config.quiet: print(f"Warp {warp.config.version} initialized:") if cuda_device_count > 0: 
toolkit_version = (self.toolkit_version // 1000, (self.toolkit_version % 1000) // 10) driver_version = (self.driver_version // 1000, (self.driver_version % 1000) // 10) print( f" CUDA Toolkit: {toolkit_version[0]}.{toolkit_version[1]}, Driver: {driver_version[0]}.{driver_version[1]}" ) else: if self.core.is_cuda_enabled(): # Warp was compiled with CUDA support, but no devices are available print(" CUDA devices not available") else: # Warp was compiled without CUDA support print(" CUDA support not enabled in this build") print(" Devices:") print(f' "{self.cpu_device.alias}" | {self.cpu_device.name}') for cuda_device in self.cuda_devices: print(f' "{cuda_device.alias}" | {cuda_device.name} (sm_{cuda_device.arch})') print(f" Kernel cache: {warp.config.kernel_cache_dir}") # CUDA compatibility check if cuda_device_count > 0 and not self.core.is_cuda_compatibility_enabled(): if self.driver_version < self.toolkit_version: print("******************************************************************") print("* WARNING: *") print("* Warp was compiled without CUDA compatibility support *") print("* (quick build). The CUDA Toolkit version used to build *") print("* Warp is not fully supported by the current driver. *") print("* Some CUDA functionality may not work correctly! *") print("* Update the driver or rebuild Warp without the --quick flag. *") print("******************************************************************") # global tape self.tape = None def load_dll(self, dll_path): try: if sys.version_info[0] > 3 or sys.version_info[0] == 3 and sys.version_info[1] >= 8: dll = ctypes.CDLL(dll_path, winmode=0) else: dll = ctypes.CDLL(dll_path) except OSError as e: if "GLIBCXX" in str(e): raise RuntimeError( f"Failed to load the shared library '{dll_path}'.\n" "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n" "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details." 
) from e else: raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e return dll def get_device(self, ident: Devicelike = None) -> Device: if isinstance(ident, Device): return ident elif ident is None: return self.default_device elif isinstance(ident, str): if ident == "cuda": return self.get_current_cuda_device() else: return self.device_map[ident] else: raise RuntimeError(f"Unable to resolve device from argument of type {type(ident)}") def set_default_device(self, ident: Devicelike): self.default_device = self.get_device(ident) def get_current_cuda_device(self): current_context = self.core.cuda_context_get_current() if current_context is not None: current_device = self.context_map.get(current_context) if current_device is not None: # this is a known device return current_device elif self.core.cuda_context_is_primary(current_context): # this is a primary context that we haven't used yet ordinal = self.core.cuda_context_get_device_ordinal(current_context) device = self.cuda_devices[ordinal] self.context_map[current_context] = device return device else: # this is an unseen non-primary context, register it as a new device with a unique alias alias = f"cuda!{current_context:x}" return self.map_cuda_device(alias, current_context) elif self.default_device.is_cuda: return self.default_device elif self.cuda_devices: return self.cuda_devices[0] else: raise RuntimeError("CUDA is not available") def rename_device(self, device, alias): del self.device_map[device.alias] device.alias = alias self.device_map[alias] = device return device def map_cuda_device(self, alias, context=None) -> Device: if context is None: context = self.core.cuda_context_get_current() if context is None: raise RuntimeError(f"Unable to determine CUDA context for device alias '{alias}'") # check if this alias already exists if alias in self.device_map: device = self.device_map[alias] if context == device.context: # device already exists with the same alias, that's fine return device 
else: raise RuntimeError(f"Device alias '{alias}' already exists") # check if this context already has an associated Warp device if context in self.context_map: # rename the device device = self.context_map[context] return self.rename_device(device, alias) else: # it's an unmapped context # get the device ordinal ordinal = self.core.cuda_context_get_device_ordinal(context) # check if this is a primary context (we could get here if it's a device that hasn't been used yet) if self.core.cuda_context_is_primary(context): # rename the device device = self.cuda_primary_devices[ordinal] return self.rename_device(device, alias) else: # create a new Warp device for this context device = Device(self, alias, ordinal=ordinal, is_primary=False, context=context) self.device_map[alias] = device self.context_map[context] = device self.cuda_devices.append(device) return device def unmap_cuda_device(self, alias): device = self.device_map.get(alias) # make sure the alias refers to a CUDA device if device is None or not device.is_cuda: raise RuntimeError(f"Invalid CUDA device alias '{alias}'") del self.device_map[alias] del self.context_map[device.context] self.cuda_devices.remove(device) def verify_cuda_device(self, device: Devicelike = None): if warp.config.verify_cuda: device = runtime.get_device(device) if not device.is_cuda: return err = self.core.cuda_context_check(device.context) if err != 0: raise RuntimeError(f"CUDA error detected: {err}") def assert_initialized(): assert runtime is not None, "Warp not initialized, call wp.init() before use" # global entry points def is_cpu_available(): return runtime.llvm def is_cuda_available(): return get_cuda_device_count() > 0 def is_device_available(device): return device in get_devices() def is_cuda_driver_initialized() -> bool: """Returns ``True`` if the CUDA driver is initialized. This is a stricter test than ``is_cuda_available()`` since a CUDA driver call to ``cuCtxGetCurrent`` is made, and the result is compared to `CUDA_SUCCESS`. 
Note that `CUDA_SUCCESS` is returned by ``cuCtxGetCurrent`` even if there is no context bound to the calling CPU thread. This can be helpful in cases in which ``cuInit()`` was called before a fork. """ assert_initialized() return runtime.core.cuda_driver_is_initialized() def get_devices() -> List[Device]: """Returns a list of devices supported in this environment.""" assert_initialized() devices = [] if is_cpu_available(): devices.append(runtime.cpu_device) for cuda_device in runtime.cuda_devices: devices.append(cuda_device) return devices def get_cuda_device_count() -> int: """Returns the number of CUDA devices supported in this environment.""" assert_initialized() return len(runtime.cuda_devices) def get_cuda_device(ordinal: Union[int, None] = None) -> Device: """Returns the CUDA device with the given ordinal or the current CUDA device if ordinal is None.""" assert_initialized() if ordinal is None: return runtime.get_current_cuda_device() else: return runtime.cuda_devices[ordinal] def get_cuda_devices() -> List[Device]: """Returns a list of CUDA devices supported in this environment.""" assert_initialized() return runtime.cuda_devices def get_preferred_device() -> Device: """Returns the preferred compute device, CUDA if available and CPU otherwise.""" assert_initialized() if is_cuda_available(): return runtime.cuda_devices[0] elif is_cpu_available(): return runtime.cpu_device else: return None def get_device(ident: Devicelike = None) -> Device: """Returns the device identified by the argument.""" assert_initialized() return runtime.get_device(ident) def set_device(ident: Devicelike): """Sets the target device identified by the argument.""" assert_initialized() device = runtime.get_device(ident) runtime.set_default_device(device) device.make_current() def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device: """Assign a device alias to a CUDA context. This function can be used to create a wp.Device for an external CUDA context. 
If a wp.Device already exists for the given context, it's alias will change to the given value. Args: alias: A unique string to identify the device. context: A CUDA context pointer (CUcontext). If None, the currently bound CUDA context will be used. Returns: The associated wp.Device. """ assert_initialized() return runtime.map_cuda_device(alias, context) def unmap_cuda_device(alias: str): """Remove a CUDA device with the given alias.""" assert_initialized() runtime.unmap_cuda_device(alias) def get_stream(device: Devicelike = None) -> Stream: """Return the stream currently used by the given device""" return get_device(device).stream def set_stream(stream, device: Devicelike = None): """Set the stream to be used by the given device. If this is an external stream, caller is responsible for guaranteeing the lifetime of the stream. Consider using wp.ScopedStream instead. """ get_device(device).stream = stream def record_event(event: Event = None): """Record a CUDA event on the current stream. Args: event: Event to record. If None, a new Event will be created. Returns: The recorded event. """ return get_stream().record_event(event) def wait_event(event: Event): """Make the current stream wait for a CUDA event. Args: event: Event to wait for. """ get_stream().wait_event(event) def wait_stream(stream: Stream, event: Event = None): """Make the current stream wait for another CUDA stream to complete its work. Args: event: Event to be used. If None, a new Event will be created. """ get_stream().wait_stream(stream, event=event) class RegisteredGLBuffer: """ Helper object to register a GL buffer with CUDA so that it can be mapped to a Warp array. """ # Specifies no hints about how this resource will be used. # It is therefore assumed that this resource will be # read from and written to by CUDA. This is the default value. NONE = 0x00 # Specifies that CUDA will not write to this resource. 
READ_ONLY = 0x01 # Specifies that CUDA will not read from this resource and will write over the # entire contents of the resource, so none of the data previously # stored in the resource will be preserved. WRITE_DISCARD = 0x02 def __init__(self, gl_buffer_id: int, device: Devicelike = None, flags: int = NONE): """Create a new RegisteredGLBuffer object. Args: gl_buffer_id: The OpenGL buffer id (GLuint). device: The device to register the buffer with. If None, the current device will be used. flags: A combination of the flags constants. """ self.gl_buffer_id = gl_buffer_id self.device = get_device(device) self.context = self.device.context self.resource = runtime.core.cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags) def __del__(self): runtime.core.cuda_graphics_unregister_resource(self.context, self.resource) def map(self, dtype, shape) -> warp.array: """Map the OpenGL buffer to a Warp array. Args: dtype: The type of each element in the array. shape: The shape of the array. Returns: A Warp array object representing the mapped OpenGL buffer. 
""" runtime.core.cuda_graphics_map(self.context, self.resource) ctypes.POINTER(ctypes.c_uint64), ctypes.POINTER(ctypes.c_size_t) ptr = ctypes.c_uint64(0) size = ctypes.c_size_t(0) runtime.core.cuda_graphics_device_ptr_and_size( self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size) ) return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device, owner=False) def unmap(self): """Unmap the OpenGL buffer.""" runtime.core.cuda_graphics_unmap(self.context, self.resource) def zeros( shape: Tuple = None, dtype=float, device: Devicelike = None, requires_grad: bool = False, pinned: bool = False, **kwargs, ) -> warp.array: """Return a zero-initialized array Args: shape: Array dimensions dtype: Type of each element, e.g.: warp.vec3, warp.mat33, etc device: Device that array will live on requires_grad: Whether the array will be tracked for back propagation pinned: Whether the array uses pinned host memory (only applicable to CPU arrays) Returns: A warp.array object representing the allocation """ arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs) # use the CUDA default stream for synchronous behaviour with other streams with warp.ScopedStream(arr.device.null_stream): arr.zero_() return arr def zeros_like( src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None ) -> warp.array: """Return a zero-initialized array with the same type and dimension of another array Args: src: The template array to use for shape, data type, and device device: The device where the new array will be created (defaults to src.device) requires_grad: Whether the array will be tracked for back propagation pinned: Whether the array uses pinned host memory (only applicable to CPU arrays) Returns: A warp.array object representing the allocation """ arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned) arr.zero_() return arr def full( shape: Tuple = None, value=0, 
dtype=Any, device: Devicelike = None, requires_grad: bool = False, pinned: bool = False, **kwargs, ) -> warp.array: """Return an array with all elements initialized to the given value Args: shape: Array dimensions value: Element value dtype: Type of each element, e.g.: float, warp.vec3, warp.mat33, etc device: Device that array will live on requires_grad: Whether the array will be tracked for back propagation pinned: Whether the array uses pinned host memory (only applicable to CPU arrays) Returns: A warp.array object representing the allocation """ if dtype == Any: # determine dtype from value value_type = type(value) if value_type == int: dtype = warp.int32 elif value_type == float: dtype = warp.float32 elif value_type in warp.types.scalar_types or hasattr(value_type, "_wp_scalar_type_"): dtype = value_type elif isinstance(value, warp.codegen.StructInstance): dtype = value._cls elif hasattr(value, "__len__"): # a sequence, assume it's a vector or matrix value try: # try to convert to a numpy array first na = np.array(value, copy=False) except Exception as e: raise ValueError(f"Failed to interpret the value as a vector or matrix: {e}") # determine the scalar type scalar_type = warp.types.np_dtype_to_warp_type.get(na.dtype) if scalar_type is None: raise ValueError(f"Failed to convert {na.dtype} to a Warp data type") # determine if vector or matrix if na.ndim == 1: dtype = warp.types.vector(na.size, scalar_type) elif na.ndim == 2: dtype = warp.types.matrix(na.shape, scalar_type) else: raise ValueError("Values with more than two dimensions are not supported") else: raise ValueError(f"Invalid value type for Warp array: {value_type}") arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs) # use the CUDA default stream for synchronous behaviour with other streams with warp.ScopedStream(arr.device.null_stream): arr.fill_(value) return arr def full_like( src: warp.array, value: Any, device: Devicelike = None, 
requires_grad: bool = None, pinned: bool = None ) -> warp.array: """Return an array with all elements initialized to the given value with the same type and dimension of another array Args: src: The template array to use for shape, data type, and device value: Element value device: The device where the new array will be created (defaults to src.device) requires_grad: Whether the array will be tracked for back propagation pinned: Whether the array uses pinned host memory (only applicable to CPU arrays) Returns: A warp.array object representing the allocation """ arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned) arr.fill_(value) return arr def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None) -> warp.array: """Clone an existing array, allocates a copy of the src memory Args: src: The source array to copy device: The device where the new array will be created (defaults to src.device) requires_grad: Whether the array will be tracked for back propagation pinned: Whether the array uses pinned host memory (only applicable to CPU arrays) Returns: A warp.array object representing the allocation """ arr = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned) warp.copy(arr, src) return arr def empty( shape: Tuple = None, dtype=float, device: Devicelike = None, requires_grad: bool = False, pinned: bool = False, **kwargs, ) -> warp.array: """Returns an uninitialized array Args: shape: Array dimensions dtype: Type of each element, e.g.: `warp.vec3`, `warp.mat33`, etc device: Device that array will live on requires_grad: Whether the array will be tracked for back propagation pinned: Whether the array uses pinned host memory (only applicable to CPU arrays) Returns: A warp.array object representing the allocation """ # backwards compatibility for case where users called wp.empty(n=length, ...) 
if "n" in kwargs: shape = (kwargs["n"],) del kwargs["n"] # ensure shape is specified, even if creating a zero-sized array if shape is None: shape = 0 return warp.array(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs) def empty_like( src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None ) -> warp.array: """Return an uninitialized array with the same type and dimension of another array Args: src: The template array to use for shape, data type, and device device: The device where the new array will be created (defaults to src.device) requires_grad: Whether the array will be tracked for back propagation pinned: Whether the array uses pinned host memory (only applicable to CPU arrays) Returns: A warp.array object representing the allocation """ if device is None: device = src.device if requires_grad is None: if hasattr(src, "requires_grad"): requires_grad = src.requires_grad else: requires_grad = False if pinned is None: if hasattr(src, "pinned"): pinned = src.pinned else: pinned = False arr = empty(shape=src.shape, dtype=src.dtype, device=device, requires_grad=requires_grad, pinned=pinned) return arr def from_numpy( arr: np.ndarray, dtype: Optional[type] = None, shape: Optional[Sequence[int]] = None, device: Optional[Devicelike] = None, requires_grad: bool = False, ) -> warp.array: if dtype is None: base_type = warp.types.np_dtype_to_warp_type.get(arr.dtype) if base_type is None: raise RuntimeError("Unsupported NumPy data type '{}'.".format(arr.dtype)) dim_count = len(arr.shape) if dim_count == 2: dtype = warp.types.vector(length=arr.shape[1], dtype=base_type) elif dim_count == 3: dtype = warp.types.matrix(shape=(arr.shape[1], arr.shape[2]), dtype=base_type) else: dtype = base_type return warp.array( data=arr, dtype=dtype, shape=shape, owner=False, device=device, requires_grad=requires_grad, ) # given a kernel destination argument type and a value convert # to a c-type that can be 
# passed to a kernel
def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
    """Convert a single Python-side kernel argument into the ctypes value handed to the kernel.

    Args:
        kernel: Kernel being launched, only used (via ``kernel.key``) in error messages
        arg_type: Declared type of the kernel argument
        arg_name: Name of the kernel argument, used in error messages
        value: The value supplied by the caller
        device: The device the kernel is launched on; array arguments must live on it
        adjoint: True when packing an adjoint (gradient) argument

    Returns:
        The packed ctypes representation of ``value``.

    Raises:
        RuntimeError: If an array argument has the wrong array type, dtype, dimensionality or device,
            or if a scalar cannot be packed into ``arg_type``.
        ValueError: If a value cannot be converted to the required ``ctypes.Array``-based type.
    """
    if warp.types.is_array(arg_type):
        if value is None:
            # allow for NULL arrays
            return arg_type.__ctype__()

        else:
            # check for array type
            # - in forward passes, array types have to match
            # - in backward passes, indexed array gradients are regular arrays
            if adjoint:
                array_matches = isinstance(value, warp.array)
            else:
                array_matches = type(value) is type(arg_type)

            if not array_matches:
                adj = "adjoint " if adjoint else ""
                raise RuntimeError(
                    f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array of type {type(arg_type)}, but passed value has type {type(value)}."
                )

            # check subtype
            if not warp.types.types_equal(value.dtype, arg_type.dtype):
                adj = "adjoint " if adjoint else ""
                raise RuntimeError(
                    f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with dtype={arg_type.dtype} but passed array has dtype={value.dtype}."
                )

            # check dimensions
            if value.ndim != arg_type.ndim:
                adj = "adjoint " if adjoint else ""
                raise RuntimeError(
                    f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with {arg_type.ndim} dimension(s) but the passed array has {value.ndim} dimension(s)."
                )

            # check device
            # if a.device != device and not device.can_access(a.device):
            if value.device != device:
                raise RuntimeError(
                    f"Error launching kernel '{kernel.key}', trying to launch on device='{device}', but input array for argument '{arg_name}' is on device={value.device}."
                )

            return value.__ctype__()

    elif isinstance(arg_type, warp.codegen.Struct):
        assert value is not None
        return value.__ctype__()

    # try to convert to a value type (vec3, mat33, etc)
    elif issubclass(arg_type, ctypes.Array):
        if warp.types.types_equal(type(value), arg_type):
            return value
        else:
            # try constructing the required value from the argument (handles tuple / list, Gf.Vec3 case)
            try:
                return arg_type(value)
            except Exception as e:
                # chain the original error so the reason for the failed conversion is not lost
                raise ValueError(
                    f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}"
                ) from e

    elif isinstance(value, bool):
        return ctypes.c_bool(value)

    elif isinstance(value, arg_type):
        # value is already an instance of the declared scalar type, unwrap its `.value`
        try:
            # try to pack as a scalar type
            if arg_type is warp.types.float16:
                return arg_type._type_(warp.types.float_to_half_bits(value.value))
            else:
                return arg_type._type_(value.value)
        except Exception as e:
            raise RuntimeError(
                "Error launching kernel, unable to pack kernel parameter type "
                f"{type(value)} for param {arg_name}, expected {arg_type}"
            ) from e

    else:
        try:
            # try to pack as a scalar type
            if arg_type is warp.types.float16:
                return arg_type._type_(warp.types.float_to_half_bits(value))
            else:
                return arg_type._type_(value)

        except Exception as e:
            # the underlying error is attached as the cause instead of being printed to stdout
            raise RuntimeError(
                "Error launching kernel, unable to pack kernel parameter type "
                f"{type(value)} for param {arg_name}, expected {arg_type}"
            ) from e


# represents all data required for a kernel launch
# so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
class Launch:
    """Holds the kernel, hooks, packed parameters and bounds of one kernel launch so it can be re-issued.

    ``params[0]`` always holds the launch bounds; kernel argument ``i`` is stored at ``params[i + 1]``.
    ``params_addr`` (when set) is a ctypes array with the address of each entry of ``params`` and is what
    gets passed to ``cuda_launch_kernel``.
    """

    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
        """Create a launch command, filling in defaults for anything that was not supplied.

        Args:
            kernel: The kernel to launch
            device: The device to launch on
            hooks: Kernel hooks; looked up from the kernel's module if not given
            params: Packed parameter list (bounds first); default-valued parameters are built if not given
            params_addr: ctypes array of parameter addresses; rebuilt whenever ``params`` is built here
            bounds: Launch bounds; a zero bound is used if not given
            max_blocks: Maximum number of CUDA thread blocks, forwarded to ``cuda_launch_kernel``

        Note:
            If hooks are not given and the kernel's module fails to load, this returns early and
            leaves the instance without any attributes set.
        """
        # if not specified look up hooks
        if not hooks:
            module = kernel.module
            if not module.load(device):
                return

            hooks = module.get_kernel_hooks(kernel, device)

        # if not specified set a zero bound
        if not bounds:
            bounds = warp.types.launch_bounds_t(0)

        # if not specified then build a list of default value params for args
        if not params:
            params = []
            params.append(bounds)

            for a in kernel.adj.args:
                if isinstance(a.type, warp.types.array):
                    params.append(a.type.__ctype__())
                elif isinstance(a.type, warp.codegen.Struct):
                    params.append(a.type().__ctype__())
                else:
                    params.append(pack_arg(kernel, a.type, a.label, 0, device, False))

            kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
            kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)

            params_addr = kernel_params

        self.kernel = kernel
        self.hooks = hooks
        self.params = params
        self.params_addr = params_addr
        self.device = device
        self.bounds = bounds
        self.max_blocks = max_blocks

    def set_dim(self, dim):
        """Replace the launch bounds with new bounds built from ``dim``."""
        self.bounds = warp.types.launch_bounds_t(dim)

        # launch bounds always at index 0
        self.params[0] = self.bounds

        # for CUDA kernels we need to update the address to each arg
        if self.params_addr:
            self.params_addr[0] = ctypes.c_void_p(ctypes.addressof(self.bounds))

    def set_param_at_index(self, index, value):
        """Set kernel param at an index, will convert to ctype as necessary."""
        arg_type = self.kernel.adj.args[index].type
        arg_name = self.kernel.adj.args[index].label

        carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, False)

        # +1 because the launch bounds occupy slot 0
        self.params[index + 1] = carg

        # for CUDA kernels we need to update the address to each arg
        if self.params_addr:
            self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(carg))

    def set_param_at_index_from_ctype(self, index, value):
        """Set kernel param at an index without any type conversion.

        Args must be passed as ctypes or basic int / float types.
        """
        if isinstance(value, ctypes.Structure):
            # not sure how to directly assign struct->struct without reallocating using ctypes
            self.params[index + 1] = value

            # for CUDA kernels we need to update the address to each arg
            if self.params_addr:
                self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(value))

        else:
            # re-initialize the existing ctypes object in place so its address stays valid
            self.params[index + 1].__init__(value)

    def set_param_by_name(self, name, value):
        """Set kernel param by argument name; names that match no argument are ignored."""
        for i, arg in enumerate(self.kernel.adj.args):
            if arg.label == name:
                self.set_param_at_index(i, value)

    def set_param_by_name_from_ctype(self, name, value):
        """Set kernel param by argument name with no type conversions; unknown names are ignored."""
        # lookup argument index
        for i, arg in enumerate(self.kernel.adj.args):
            if arg.label == name:
                self.set_param_at_index_from_ctype(i, value)

    def set_params(self, values):
        """Set all params, in argument order, converting each to its ctype."""
        for i, v in enumerate(values):
            self.set_param_at_index(i, v)

    def set_params_from_ctypes(self, values):
        """Set all params, in argument order, without performing type-conversions."""
        for i, v in enumerate(values):
            self.set_param_at_index_from_ctype(i, v)

    def launch(self) -> Any:
        """Run the forward kernel with the currently stored parameters and bounds."""
        if self.device.is_cpu:
            self.hooks.forward(*self.params)
        else:
            runtime.core.cuda_launch_kernel(
                self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
            )


def launch(
    kernel,
    dim: Tuple[int],
    inputs: List,
    outputs: List = [],
    adj_inputs: List = [],
    adj_outputs: List = [],
    device: Devicelike = None,
    stream: Stream = None,
    adjoint=False,
    record_tape=True,
    record_cmd=False,
    max_blocks=0,
):
    """Launch a Warp kernel on the target device

    Kernel launches are asynchronous with respect to the calling Python thread.

    Args:
        kernel: The name of a Warp kernel function, decorated with the ``@wp.kernel`` decorator
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints with max of 4 dimensions
        inputs: The input parameters to the kernel
        outputs: The output parameters (optional)
        adj_inputs: The adjoint inputs (optional)
        adj_outputs: The adjoint outputs (optional)
        device: The device to launch on (optional)
        stream: The stream to launch on (optional)
        adjoint: Whether to run forward or backward pass (typically use False)
        record_tape: When true the launch will be recorded on the global wp.Tape() object when present
        record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
        max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
            If negative or zero, the maximum hardware value will be used.

    Returns:
        A ``Launch`` object when ``record_cmd`` is True and a forward launch is requested, otherwise None.

    Raises:
        RuntimeError: If ``kernel`` is not a ``Kernel``, the number of arguments does not match the
            kernel signature, an argument cannot be packed, or the required kernel hook is missing.
    """

    assert_initialized()

    # if stream is specified, use the associated device
    if stream is not None:
        device = stream.device
    else:
        device = runtime.get_device(device)

    # check function is a Kernel
    if not isinstance(kernel, Kernel):
        raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")

    # debugging aid
    if warp.config.print_launches:
        print(f"kernel: {kernel.key} dim: {dim} inputs: {inputs} outputs: {outputs} device: {device}")

    # construct launch bounds
    bounds = warp.types.launch_bounds_t(dim)

    if bounds.size > 0:
        # first param is the number of threads
        params = []
        params.append(bounds)

        # converts arguments to kernel's expected ctypes and packs into params
        # NOTE: `kernel` is read at call time, i.e. after a generic kernel has been resolved below
        def pack_args(args, params, adjoint=False):
            for i, a in enumerate(args):
                arg_type = kernel.adj.args[i].type
                arg_name = kernel.adj.args[i].label

                params.append(pack_arg(kernel, arg_type, arg_name, a, device, adjoint))

        # the (shared) default lists are only read here, concatenation creates new lists
        fwd_args = inputs + outputs
        adj_args = adj_inputs + adj_outputs

        if (len(fwd_args)) != (len(kernel.adj.args)):
            raise RuntimeError(
                f"Error launching kernel '{kernel.key}', passed {len(fwd_args)} arguments but kernel requires {len(kernel.adj.args)}."
            )

        # if it's a generic kernel, infer the required overload from the arguments
        if kernel.is_generic:
            fwd_types = kernel.infer_argument_types(fwd_args)
            kernel = kernel.get_overload(fwd_types)

        # delay load modules, including new overload if needed
        module = kernel.module
        if not module.load(device):
            return

        # late bind
        hooks = module.get_kernel_hooks(kernel, device)

        pack_args(fwd_args, params)
        pack_args(adj_args, params, adjoint=True)

        # run kernel
        if device.is_cpu:
            if adjoint:
                if hooks.backward is None:
                    raise RuntimeError(
                        f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                    )

                hooks.backward(*params)

            else:
                if hooks.forward is None:
                    raise RuntimeError(
                        f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                    )

                if record_cmd:
                    # forward max_blocks so the recorded command matches a direct launch
                    return Launch(
                        kernel=kernel,
                        hooks=hooks,
                        params=params,
                        params_addr=None,
                        bounds=bounds,
                        device=device,
                        max_blocks=max_blocks,
                    )
                else:
                    hooks.forward(*params)

        else:
            # CUDA launches take an array holding the address of every parameter
            kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
            kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)

            with warp.ScopedStream(stream):
                if adjoint:
                    if hooks.backward is None:
                        raise RuntimeError(
                            f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                        )

                    runtime.core.cuda_launch_kernel(
                        device.context, hooks.backward, bounds.size, max_blocks, kernel_params
                    )

                else:
                    if hooks.forward is None:
                        raise RuntimeError(
                            f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                        )

                    if record_cmd:
                        # forward max_blocks, otherwise the recorded command would always use 0
                        return Launch(
                            kernel=kernel,
                            hooks=hooks,
                            params=params,
                            params_addr=kernel_params,
                            bounds=bounds,
                            device=device,
                            max_blocks=max_blocks,
                        )

                    else:
                        # launch
                        runtime.core.cuda_launch_kernel(
                            device.context, hooks.forward, bounds.size, max_blocks, kernel_params
                        )

                try:
                    runtime.verify_cuda_device(device)
                except Exception as e:
                    print(f"Error launching kernel: {kernel.key} on device {device}")
                    raise e

    # record on tape if one is active
    if runtime.tape and record_tape:
        runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device)


def synchronize():
    """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices

    This method allows the host application code to ensure that any kernel launches
    or memory copies have completed.

    Raises:
        RuntimeError: If a device with a context is currently capturing a graph.
    """

    if is_cuda_driver_initialized():
        # save the original context to avoid side effects
        saved_context = runtime.core.cuda_context_get_current()

        try:
            # TODO: only synchronize devices that have outstanding work
            for device in runtime.cuda_devices:
                # avoid creating primary context if the device has not been used yet
                if device.has_context:
                    if device.is_capturing:
                        raise RuntimeError(f"Cannot synchronize device {device} while graph capture is active")

                    runtime.core.cuda_context_synchronize(device.context)
        finally:
            # restore the original context to avoid side effects, also when an error is raised above
            runtime.core.cuda_context_set_current(saved_context)


def synchronize_device(device: Devicelike = None):
    """Manually synchronize the calling CPU thread with any outstanding CUDA work on the specified device

    This method allows the host application code to ensure that any kernel launches
    or memory copies have completed.

    Args:
        device: Device to synchronize. If None, synchronize the current CUDA device.

    Raises:
        RuntimeError: If the device is currently capturing a graph.
    """

    device = runtime.get_device(device)
    if device.is_cuda:
        if device.is_capturing:
            raise RuntimeError(f"Cannot synchronize device {device} while graph capture is active")

        runtime.core.cuda_context_synchronize(device.context)


def synchronize_stream(stream_or_device=None):
    """Manually synchronize the calling CPU thread with any outstanding CUDA work on the specified stream.

    Args:
        stream_or_device: `wp.Stream` or a device. If the argument is a device, synchronize the device's current stream.
    """

    if isinstance(stream_or_device, Stream):
        stream = stream_or_device
    else:
        stream = runtime.get_device(stream_or_device).stream

    runtime.core.cuda_stream_synchronize(stream.device.context, stream.cuda_stream)


def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
    """Force user-defined kernels to be compiled and loaded

    Args:
        device: The device or list of devices to load the modules on. If None, load on all devices.
        modules: List of modules to load. If None, load all imported modules.
    """

    # A flag is needed because the saved context itself may legitimately be a null (None) handle.
    context_saved = False
    saved_context = None

    if is_cuda_driver_initialized():
        # save original context to avoid side effects
        saved_context = runtime.core.cuda_context_get_current()
        context_saved = True

    try:
        if device is None:
            devices = get_devices()
        elif isinstance(device, list):
            devices = [get_device(device_item) for device_item in device]
        else:
            devices = [get_device(device)]

        if modules is None:
            modules = user_modules.values()

        for d in devices:
            for m in modules:
                m.load(d)
    finally:
        # Only restore a context that was actually saved; previously an uninitialized driver at
        # entry could lead to restoring an unbound variable, and a failing load skipped the restore.
        if context_saved and is_cuda_available():
            # restore original context to avoid side effects
            runtime.core.cuda_context_set_current(saved_context)


def load_module(
    module: Union[Module, ModuleType, str] = None, device: Union[Device, str] = None, recursive: bool = False
):
    """Force user-defined module to be compiled and loaded

    Args:
        module: The module to load. If None, load the current module.
        device: The device to load the modules on. If None, load on all devices.
        recursive: Whether to load submodules. E.g., if the given module is `warp.sim`, this will also load `warp.sim.model`, `warp.sim.articulation`, etc.

    Raises:
        TypeError: If ``module`` is not a ``Module``, a Python module, a string or None.

    Note: A module must be imported before it can be loaded by this function.
    """

    if module is None:
        # if module not specified, use the module that called us
        module = inspect.getmodule(inspect.stack()[1][0])
        module_name = module.__name__
    elif isinstance(module, Module):
        module_name = module.name
    elif isinstance(module, ModuleType):
        module_name = module.__name__
    elif isinstance(module, str):
        module_name = module
    else:
        raise TypeError(f"Argument must be a module, got {type(module)}")

    modules = []

    # add the given module, if found
    m = user_modules.get(module_name)
    if m is not None:
        modules.append(m)

    # add submodules, if recursive
    if recursive:
        prefix = module_name + "."
        for name, mod in user_modules.items():
            if name.startswith(prefix):
                modules.append(mod)

    force_load(device=device, modules=modules)


def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
    """Set options for the current module.

    Options can be used to control runtime compilation and code-generation
    for the current module individually. Available options are listed below.

    * **mode**: The compilation mode to use, can be "debug", or "release", defaults to the value of ``warp.config.mode``.
    * **max_unroll**: The maximum fixed-size loop to unroll (default 16)

    Args:
        options: Set of key-value option pairs
        module: The module to update; if None, the module of the caller is used
    """

    if module is None:
        m = inspect.getmodule(inspect.stack()[1][0])
    else:
        m = module

    get_module(m.__name__).options.update(options)
    # unload after changing the options
    get_module(m.__name__).unload()


def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
    """Returns the options of the given module, or of the calling module if ``module`` is None."""
    if module is None:
        m = inspect.getmodule(inspect.stack()[1][0])
    else:
        m = module

    return get_module(m.__name__).options


def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
    """Begin capture of a CUDA graph

    Captures all subsequent kernel launches and memory operations on CUDA devices.
    This can be used to record large numbers of kernels and replay them with low-overhead.

    Args:
        device: The device to capture on, if None the current CUDA device will be used
        stream: The CUDA stream to capture on
        force_module_load: Whether or not to force loading of all kernels before capture, in general it is better to use :func:`~warp.load_module()` to selectively load kernels.

    Raises:
        RuntimeError: If ``warp.config.verify_cuda`` is enabled or the device is not a CUDA device.
    """

    if force_module_load is None:
        force_module_load = warp.config.graph_capture_module_load_default

    if warp.config.verify_cuda:
        raise RuntimeError("Cannot use CUDA error verification during graph capture")

    if stream is not None:
        device = stream.device
    else:
        device = runtime.get_device(device)
        if not device.is_cuda:
            raise RuntimeError("Must be a CUDA device")

    if force_module_load:
        force_load(device)

    device.is_capturing = True

    # disable garbage collection to avoid older allocations getting collected during graph capture
    gc.disable()

    with warp.ScopedStream(stream):
        runtime.core.cuda_graph_begin_capture(device.context)


def capture_end(device: Devicelike = None, stream=None) -> Graph:
    """Ends the capture of a CUDA graph

    Args:
        device: The device the capture was started on, if None the current CUDA device will be used
        stream: The CUDA stream the capture was started on

    Returns:
        A handle to a CUDA graph object that can be launched with :func:`~warp.capture_launch()`

    Raises:
        RuntimeError: If the device is not a CUDA device or no graph was produced by the capture.
    """

    if stream is not None:
        device = stream.device
    else:
        device = runtime.get_device(device)
        if not device.is_cuda:
            raise RuntimeError("Must be a CUDA device")

    with warp.ScopedStream(stream):
        graph = runtime.core.cuda_graph_end_capture(device.context)

    device.is_capturing = False

    # re-enable GC
    gc.enable()

    if graph is None:
        raise RuntimeError(
            "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
        )
    else:
        return Graph(device, graph)


def capture_launch(graph: Graph, stream: Stream = None):
    """Launch a previously captured CUDA graph

    Args:
        graph: A Graph as returned by :func:`~warp.capture_end()`
        stream: A Stream to launch the graph on (optional)

    Raises:
        RuntimeError: If ``stream`` belongs to a different device than ``graph``.
    """

    if stream is not None:
        if stream.device != graph.device:
            raise RuntimeError(f"Cannot launch graph from device {graph.device} on stream from device {stream.device}")
        device = stream.device
    else:
        device = graph.device

    with warp.ScopedStream(stream):
        runtime.core.cuda_graph_launch(device.context, graph.exec)


def copy(
    dest: warp.array, src: warp.array, dest_offset: int = 0, src_offset: int = 0, count: int = 0, stream: Stream = None
):
    """Copy array contents from src to dest

    Args:
        dest: Destination array, must be at least as big as source buffer
        src: Source array
        dest_offset: Element offset in the destination array
        src_offset: Element offset in the source array
        count: Number of array elements to copy (will copy all elements if set to 0)
        stream: The stream on which to perform the copy (optional)

    Raises:
        RuntimeError: If either argument is not an array, the requested range does not fit in the
            source or destination, or (for non-contiguous copies) shapes or element sizes differ.

    Note:
        ``dest_offset``, ``src_offset`` and ``count`` are only used when both arrays are contiguous;
        the non-contiguous path copies whole arrays of identical shape.
    """

    if not warp.types.is_array(src) or not warp.types.is_array(dest):
        raise RuntimeError("Copy source and destination must be arrays")

    # backwards compatibility, if count is zero then copy entire src array
    if count <= 0:
        count = src.size

    if count == 0:
        return

    # copying non-contiguous arrays requires that they are on the same device
    if not (src.is_contiguous and dest.is_contiguous) and src.device != dest.device:
        if dest.is_contiguous:
            # make a contiguous copy of the source array
            src = src.contiguous()
        else:
            # make a copy of the source array on the destination device
            src = src.to(dest.device)

    if src.is_contiguous and dest.is_contiguous:
        bytes_to_copy = count * warp.types.type_size_in_bytes(src.dtype)

        src_size_in_bytes = src.size * warp.types.type_size_in_bytes(src.dtype)
        dst_size_in_bytes = dest.size * warp.types.type_size_in_bytes(dest.dtype)

        src_offset_in_bytes = src_offset * warp.types.type_size_in_bytes(src.dtype)
        dst_offset_in_bytes = dest_offset * warp.types.type_size_in_bytes(dest.dtype)

        src_ptr = src.ptr + src_offset_in_bytes
        dst_ptr = dest.ptr + dst_offset_in_bytes

        if src_offset_in_bytes + bytes_to_copy > src_size_in_bytes:
            raise RuntimeError(
                f"Trying to copy source buffer with size ({bytes_to_copy}) from offset ({src_offset_in_bytes}) is larger than source size ({src_size_in_bytes})"
            )

        if dst_offset_in_bytes + bytes_to_copy > dst_size_in_bytes:
            raise RuntimeError(
                f"Trying to copy source buffer with size ({bytes_to_copy}) to offset ({dst_offset_in_bytes}) is larger than destination size ({dst_size_in_bytes})"
            )

        if src.device.is_cpu and dest.device.is_cpu:
            runtime.core.memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
        else:
            # figure out the CUDA context/stream for the copy
            if stream is not None:
                copy_device = stream.device
            elif dest.device.is_cuda:
                copy_device = dest.device
            else:
                copy_device = src.device

            with warp.ScopedStream(stream):
                if src.device.is_cpu and dest.device.is_cuda:
                    runtime.core.memcpy_h2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
                elif src.device.is_cuda and dest.device.is_cpu:
                    runtime.core.memcpy_d2h(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
                elif src.device.is_cuda and dest.device.is_cuda:
                    if src.device == dest.device:
                        runtime.core.memcpy_d2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
                    else:
                        runtime.core.memcpy_peer(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
                else:
                    raise RuntimeError("Unexpected source and destination combination")

    else:
        # handle non-contiguous and indexed arrays

        if src.shape != dest.shape:
            raise RuntimeError("Incompatible array shapes")

        src_elem_size = warp.types.type_size_in_bytes(src.dtype)
        dst_elem_size = warp.types.type_size_in_bytes(dest.dtype)

        if src_elem_size != dst_elem_size:
            raise RuntimeError("Incompatible array data types")

        # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
        # TODO?
        if (
            isinstance(src, (warp.fabricarray, warp.indexedfabricarray))
            and src.ndim > 1
            or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
            and dest.ndim > 1
        ):
            raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")

        src_desc = src.__ctype__()
        dst_desc = dest.__ctype__()
        src_ptr = ctypes.pointer(src_desc)
        dst_ptr = ctypes.pointer(dst_desc)
        src_type = warp.types.array_type_id(src)
        dst_type = warp.types.array_type_id(dest)

        if src.device.is_cuda:
            with warp.ScopedStream(stream):
                runtime.core.array_copy_device(src.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
        else:
            runtime.core.array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)

    # copy gradient, if needed
    if hasattr(src, "grad") and src.grad is not None and hasattr(dest, "grad") and dest.grad is not None:
        # copy the same element range as the primal copy; previously the offsets and count were
        # dropped, so a partial copy transferred the whole gradient starting at element 0
        copy(dest.grad, src.grad, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream)


def type_str(t):
    """Return the display name used for type ``t`` in the generated documentation and stubs.

    Raises:
        TypeError: If ``t`` has a ``_wp_generic_type_str_`` that is not one of the known generic kinds.
    """
    if t is None:
        return "None"
    elif t == Any:
        return "Any"
    elif t == Callable:
        return "Callable"
    elif t == Tuple[int, int]:
        return "Tuple[int, int]"
    elif isinstance(t, int):
        return str(t)
    elif isinstance(t, List):
        return "Tuple[" + ", ".join(map(type_str, t)) + "]"
    elif isinstance(t, warp.array):
        return f"Array[{type_str(t.dtype)}]"
    elif isinstance(t, warp.indexedarray):
        return f"IndexedArray[{type_str(t.dtype)}]"
    elif isinstance(t, warp.fabricarray):
        return f"FabricArray[{type_str(t.dtype)}]"
    elif isinstance(t, warp.indexedfabricarray):
        return f"IndexedFabricArray[{type_str(t.dtype)}]"
    elif hasattr(t, "_wp_generic_type_str_"):
        generic_type = t._wp_generic_type_str_

        # for concrete vec/mat types use the short name
        if t in warp.types.vector_types:
            return t.__name__

        # for generic vector / matrix type use a Generic type hint
        if generic_type == "vec_t":
            # return f"Vector"
            return f"Vector[{type_str(t._wp_type_params_[0])},{type_str(t._wp_scalar_type_)}]"
        elif generic_type == "quat_t":
            # return f"Quaternion"
            return f"Quaternion[{type_str(t._wp_scalar_type_)}]"
        elif generic_type == "mat_t":
            # return f"Matrix"
            return f"Matrix[{type_str(t._wp_type_params_[0])},{type_str(t._wp_type_params_[1])},{type_str(t._wp_scalar_type_)}]"
        elif generic_type == "transform_t":
            # return f"Transformation"
            return f"Transformation[{type_str(t._wp_scalar_type_)}]"
        else:
            raise TypeError("Invalid vector or matrix dimensions")
    else:
        return t.__name__


def print_function(f, file, noentry=False):  # pragma: no cover
    """Writes a function definition to a file for use in reST documentation

    Args:
        f: The function being written
        file: The file object for output
        noentry: If True, then the :noindex: and :nocontentsentry: directive options will be added

    Returns:
        A bool indicating True if f was written to file
    """

    if f.hidden:
        return False

    args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())

    return_type = ""

    try:
        # todo: construct a default value for each of the functions args
        # so we can generate the return type for overloaded functions
        return_type = " -> " + type_str(f.value_func(None, None, None))
    except Exception:
        # best effort: the return type is simply omitted when it cannot be evaluated without arguments
        pass

    print(f".. function:: {f.key}({args}){return_type}", file=file)
    if noentry:
        print(" :noindex:", file=file)
        print(" :nocontentsentry:", file=file)
    print("", file=file)

    if f.doc != "":
        if not f.missing_grad:
            print(f" {f.doc}", file=file)
        else:
            # [1]_ refers to the footnote written by export_functions_rst()
            print(f" {f.doc} [1]_", file=file)
        print("", file=file)

    print(file=file)

    return True


def export_functions_rst(file):  # pragma: no cover
    """Writes the reST kernel reference (types and all builtin functions by group) to ``file``."""
    header = (
        "..\n"
        " Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
        "\n"
        ".. functions:\n"
        ".. currentmodule:: warp\n"
        "\n"
        "Kernel Reference\n"
        "================"
    )

    print(header, file=file)

    # type definitions of all functions by group
    print("\nScalar Types", file=file)
    print("------------", file=file)

    for t in warp.types.scalar_types:
        print(f".. class:: {t.__name__}", file=file)
    # Manually add wp.bool since it's inconvenient to add to wp.types.scalar_types:
    print(f".. class:: {warp.types.bool.__name__}", file=file)

    print("\n\nVector Types", file=file)
    print("------------", file=file)

    for t in warp.types.vector_types:
        print(f".. class:: {t.__name__}", file=file)

    print("\nGeneric Types", file=file)
    print("-------------", file=file)

    print(".. class:: Int", file=file)
    print(".. class:: Float", file=file)
    print(".. class:: Scalar", file=file)
    print(".. class:: Vector", file=file)
    print(".. class:: Matrix", file=file)
    print(".. class:: Quaternion", file=file)
    print(".. class:: Transformation", file=file)
    print(".. class:: Array", file=file)

    print("\nQuery Types", file=file)
    print("-------------", file=file)
    print(".. autoclass:: bvh_query_t", file=file)
    print(".. autoclass:: hash_grid_query_t", file=file)
    print(".. autoclass:: mesh_query_aabb_t", file=file)
    print(".. autoclass:: mesh_query_point_t", file=file)
    print(".. autoclass:: mesh_query_ray_t", file=file)

    # build dictionary of all functions by group
    groups = {}

    for k, f in builtin_functions.items():
        # build dict of groups
        if f.group not in groups:
            groups[f.group] = []

        # append all overloads to the group
        for o in f.overloads:
            groups[f.group].append(o)

    # Keep track of what function names have been written
    written_functions = {}

    for k, g in groups.items():
        print("\n", file=file)
        print(k, file=file)
        print("---------------", file=file)

        for f in g:
            if f.key in written_functions:
                # Add :noindex: + :nocontentsentry: since Sphinx gets confused
                print_function(f, file=file, noentry=True)
            else:
                if print_function(f, file=file):
                    written_functions[f.key] = []

    # footnotes
    print(".. rubric:: Footnotes", file=file)
    print(".. [1] Note: function gradients not implemented for backpropagation.", file=file)


def export_stubs(file):  # pragma: no cover
    """Generates stub file for auto-complete of builtin functions"""

    import textwrap

    print(
        "# Autogenerated file, do not edit, this file provides stubs for builtins autocomplete in VSCode, PyCharm, etc",
        file=file,
    )
    print("", file=file)
    print("from typing import Any", file=file)
    print("from typing import Tuple", file=file)
    print("from typing import Callable", file=file)
    print("from typing import TypeVar", file=file)
    print("from typing import Generic", file=file)
    print("from typing import overload as over", file=file)
    print(file=file)

    # type hints, these need to be mirrored into the stubs file
    print('Length = TypeVar("Length", bound=int)', file=file)
    print('Rows = TypeVar("Rows", bound=int)', file=file)
    print('Cols = TypeVar("Cols", bound=int)', file=file)
    print('DType = TypeVar("DType")', file=file)

    print('Int = TypeVar("Int")', file=file)
    print('Float = TypeVar("Float")', file=file)
    print('Scalar = TypeVar("Scalar")', file=file)
    print("Vector = Generic[Length, Scalar]", file=file)
    print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
    print("Quaternion = Generic[Float]", file=file)
    print("Transformation = Generic[Float]", file=file)
    print("Array = Generic[DType]", file=file)
    print("FabricArray = Generic[DType]", file=file)
    print("IndexedFabricArray = Generic[DType]", file=file)

    # prepend __init__.py
    # (read from the directory of the output file, so `file` must have a `name` attribute)
    with open(os.path.join(os.path.dirname(file.name), "__init__.py")) as header_file:
        # strip comment lines
        lines = [line for line in header_file if not line.startswith("#")]
        header = "".join(lines)

    print(header, file=file)
    print(file=file)

    for k, g in builtin_functions.items():
        for f in g.overloads:
            args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())

            return_str = ""

            if not f.export or f.hidden:  # or f.generic:
                continue

            try:
                # todo: construct a default value for each of the functions args
                # so we can generate the return type for overloaded functions
                return_type = f.value_func(None, None, None)
                if return_type:
                    return_str = " -> " + type_str(return_type)

            except Exception:
                # best effort: the return annotation is omitted when it cannot be evaluated
                pass

            print("@over", file=file)
            print(f"def {f.key}({args}){return_str}:", file=file)
            print(' """', file=file)
            print(textwrap.indent(text=f.doc, prefix=" "), file=file)
            print(' """', file=file)
            print(" ...\n\n", file=file)


def export_builtins(file: io.TextIOBase):  # pragma: no cover
    """Writes ``extern "C"`` wrapper functions for the exported, non-generic builtins to ``file``.

    Overloads that take arrays, ``Any``, ``Callable`` or ``Tuple`` arguments, are variadic, or whose
    return type cannot be evaluated without arguments are skipped.
    """

    def ctype_arg_str(t):
        # vector types are passed by reference
        if isinstance(t, int):
            return "int"
        elif isinstance(t, float):
            return "float"
        elif t in warp.types.vector_types:
            return f"{t.__name__}&"
        else:
            return t.__name__

    def ctype_ret_str(t):
        if isinstance(t, int):
            return "int"
        elif isinstance(t, float):
            return "float"
        else:
            return t.__name__

    file.write("namespace wp {\n\n")
    file.write('extern "C" {\n\n')

    for k, g in builtin_functions.items():
        for f in g.overloads:
            if not f.export or f.generic:
                continue

            simple = True
            for k, v in f.input_types.items():
                if isinstance(v, warp.array) or v == Any or v == Callable or v == Tuple:
                    simple = False
                    break

            # only export simple types that don't use arrays
            # or templated types
            if not simple or f.variadic:
                continue

            args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
            params = ", ".join(f.input_types.keys())

            return_type = ""

            try:
                # todo: construct a default value for each of the functions args
                # so we can generate the return type for overloaded functions
                return_type = ctype_ret_str(f.value_func(None, None, None))
            except Exception:
                continue

            if return_type.startswith("Tuple"):
                continue

            if args == "":
                file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
            elif return_type == "None":
                file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
            else:
                file.write(
                    f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
                )

    file.write('\n} // extern "C"\n\n')
    file.write("} // namespace wp\n")


# initialize global runtime
runtime = None


def init():
    """Initialize the Warp runtime.

    This function must be called before any other API call. The global ``runtime``
    is created on the first call only; calling it again while a runtime exists does nothing.
    If an error occurs an exception will be raised.
    """
    global runtime

    # a runtime already exists, keep it
    if runtime is not None:
        return

    runtime = Runtime()