qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import ast
import ctypes
import gc
import hashlib
import inspect
import io
import os
import platform
import sys
import types
from copy import copy as shallowcopy
from types import ModuleType
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
import warp
import warp.build
import warp.codegen
import warp.config
# helpers for Function (below), which represents either a built-in or a user-defined function
def create_value_func(type):
    """Return a ``value_func`` callback that always reports ``type`` as the result type.

    The callback takes ``(args, kwds, templates)`` like every other value_func
    but ignores all three arguments.
    """
    return lambda args, kwds, templates: type
def get_function_args(func):
    """Ensures that all function arguments are annotated and returns a dictionary mapping from argument name to its type.

    A ``return`` annotation is not an argument: it neither counts towards the
    completeness check nor appears in the returned mapping.

    Raises:
        RuntimeError: if any positional argument of ``func`` lacks an annotation.
    """
    argspec = inspect.getfullargspec(func)

    # use source-level argument annotations, excluding the return annotation
    annotations = {name: t for name, t in argspec.annotations.items() if name != "return"}

    # check every positional argument explicitly; comparing counts alone lets a
    # return annotation mask a missing argument annotation
    if any(arg not in annotations for arg in argspec.args):
        raise RuntimeError(f"Incomplete argument annotations on function {func.__qualname__}")

    return annotations
class Function:
    """A function that can be called from Warp kernels.

    Two flavors share this class:

    * built-in (native) functions, created with ``func=None`` and explicit
      ``input_types``; every overload for a key is kept in ``self.overloads``;
    * user-defined functions (``@wp.func``), where ``func`` is the decorated
      Python function. It is parsed into ``self.adj`` (a ``warp.codegen.Adjoint``)
      whose argument annotations provide ``input_types``. Overloads are indexed
      by type signature in ``user_overloads`` (concrete) and ``user_templates``
      (generic).
    """

    def __init__(
        self,
        func,
        key,
        namespace,
        input_types=None,
        value_func=None,
        template_func=None,
        module=None,
        variadic=False,
        initializer_list_func=None,
        export=False,
        doc="",
        group="",
        hidden=False,
        skip_replay=False,
        missing_grad=False,
        generic=False,
        native_func=None,
        defaults=None,
        custom_replay_func=None,
        native_snippet=None,
        adj_native_snippet=None,
        skip_forward_codegen=False,
        skip_reverse_codegen=False,
        custom_reverse_num_input_args=-1,
        custom_reverse_mode=False,
        overloaded_annotations=None,
        code_transformers=None,
        skip_adding_overload=False,
        require_original_output_arg=False,
    ):
        # points to Python function decorated with @wp.func, may be None for builtins
        self.func = func
        self.key = key
        self.namespace = namespace
        # a function that takes a list of args and a list of templates and returns the value type,
        # e.g.: load(array, index) returns the type of value being loaded
        self.value_func = value_func
        self.template_func = template_func
        self.input_types = {}
        self.export = export
        self.doc = doc
        self.group = group
        self.module = module
        # function can take arbitrary number of inputs, e.g.: printf()
        self.variadic = variadic
        self.defaults = defaults
        # Function instance for a custom implementation of the replay pass
        self.custom_replay_func = custom_replay_func
        self.native_snippet = native_snippet
        self.adj_native_snippet = adj_native_snippet
        self.custom_grad_func = None
        self.require_original_output_arg = require_original_output_arg

        # True if the arguments should be emitted as an initializer list in the c++ code
        if initializer_list_func is None:
            self.initializer_list_func = lambda x, y: False
        else:
            self.initializer_list_func = initializer_list_func

        # function will not be listed in docs
        self.hidden = hidden
        # whether or not operation will be performed during the forward replay in the backward pass
        self.skip_replay = skip_replay
        # whether or not builtin is missing a corresponding adjoint
        self.missing_grad = missing_grad
        self.generic = generic

        # allow registering builtin functions with a different name in Python from the native code
        if native_func is None:
            self.native_func = key
        else:
            self.native_func = native_func

        if func:
            # user-defined function

            # generic and concrete overload lookups by type signature
            self.user_templates = {}
            self.user_overloads = {}

            # user defined (Python) function; a fresh list avoids sharing a
            # mutable default between Function instances
            self.adj = warp.codegen.Adjoint(
                func,
                is_user_function=True,
                skip_forward_codegen=skip_forward_codegen,
                skip_reverse_codegen=skip_reverse_codegen,
                custom_reverse_num_input_args=custom_reverse_num_input_args,
                custom_reverse_mode=custom_reverse_mode,
                overload_annotations=overloaded_annotations,
                transformers=[] if code_transformers is None else code_transformers,
            )

            # record input types; the return annotation becomes the value_func
            for name, type in self.adj.arg_types.items():
                if name == "return":
                    self.value_func = create_value_func(type)
                else:
                    self.input_types[name] = type

        else:
            # builtin function

            # embedded linked list of all overloads
            # the builtin_functions dictionary holds
            # the list head for a given key (func name)
            self.overloads = []

            # builtin (native) function, canonicalize argument types;
            # input_types may be omitted for argument-less builtins
            for k, v in (input_types or {}).items():
                self.input_types[k] = warp.types.type_to_warp(v)

            # cache mangled name
            if self.is_simple():
                self.mangled_name = self.mangle()
            else:
                self.mangled_name = None

        if not skip_adding_overload:
            self.add_overload(self)

        # add to current module
        if module:
            module.register_function(self, skip_adding_overload)

    def __call__(self, *args, **kwargs):
        """Call the function from the Python interpreter (experimental).

        Built-ins try each non-generic overload via ``call_builtin`` until one
        accepts the arguments. User functions with overloads try each overload
        whose argument types unify with ``args``. Otherwise the wrapped Python
        function is called directly.

        Raises:
            RuntimeError: if no overload accepts the arguments, keyword arguments
                are given to an overloaded user function, or the function is undefined.
        """
        if self.is_builtin() and self.mangled_name:
            # For each of this function's existing overloads, we attempt to pack
            # the given arguments into the C types expected by the corresponding
            # parameters, and we rinse and repeat until we get a match.
            for overload in self.overloads:
                if overload.generic:
                    continue

                success, return_value = call_builtin(overload, *args)
                if success:
                    return return_value

            # overload resolution or call failed
            raise RuntimeError(
                f"Couldn't find a function '{self.key}' compatible with "
                f"the arguments '{', '.join(type(x).__name__ for x in args)}'"
            )

        if hasattr(self, "user_overloads") and len(self.user_overloads):
            # user-defined function with overloads

            if len(kwargs):
                raise RuntimeError(
                    f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
                )

            # try and find a matching overload
            for overload in self.user_overloads.values():
                if len(overload.input_types) != len(args):
                    continue
                template_types = list(overload.input_types.values())
                arg_names = list(overload.input_types.keys())
                try:
                    # attempt to unify argument types with function template types
                    warp.types.infer_argument_types(args, template_types, arg_names)
                    return overload.func(*args)
                except Exception:
                    continue

            raise RuntimeError(f"Error calling function '{self.key}', no overload found for arguments {args}")

        # user-defined function with no overloads
        if self.func is None:
            raise RuntimeError(f"Error calling function '{self.key}', function is undefined")

        # this function has no overloads, call it like a plain Python function
        return self.func(*args, **kwargs)

    def is_builtin(self):
        """Return True for built-in (native) functions, i.e. those without a Python body."""
        return self.func is None

    def is_simple(self):
        """Return True if this function can be exported as a plain C function.

        A function is not simple if:

        * it is variadic;
        * it takes an array, ``Any``, ``Callable`` or ``Tuple`` parameter;
        * its return type cannot be computed without concrete arguments;
        * it returns a Tuple.
        """
        if self.variadic:
            return False

        # only export simple types that don't use arrays
        for v in self.input_types.values():
            if isinstance(v, warp.array) or v == Any or v == Callable or v == Tuple:
                return False

        try:
            # todo: construct a default value for each of the functions args
            # so we can generate the return type for overloaded functions
            return_type = type_str(self.value_func(None, None, None))
        except Exception:
            return False

        if return_type.startswith("Tuple"):
            return False

        return True

    def mangle(self):
        """Build the mangled name of the C-exported function, e.g. ``builtin_normalize_vec3``."""
        name = "builtin_" + self.key
        types = [t.__name__ for t in self.input_types.values()]
        return "_".join([name, *types])

    def add_overload(self, f):
        """Register ``f`` as an overload of this function.

        Built-ins append to ``self.overloads``, with variadic overloads kept last
        so non-variadic ones are matched first. User functions are indexed by
        type signature in ``user_templates`` (generic) or ``user_overloads``.

        Raises:
            RuntimeError: if a user overload with the same signature already exists.
        """
        if self.is_builtin():
            # todo: note that it is an error to add two functions
            # with the exact same signature as this would cause compile
            # errors during compile time. We should check here if there
            # is a previously created function with the same signature
            self.overloads.append(f)

            # make sure variadic overloads appear last so non variadic
            # ones are matched first:
            self.overloads.sort(key=lambda ovl: ovl.variadic)

        else:
            # get function signature based on the input types
            sig = warp.types.get_signature(
                f.input_types.values(), func_name=f.key, arg_names=list(f.input_types.keys())
            )

            # check if generic
            if warp.types.is_generic_signature(sig):
                if sig in self.user_templates:
                    raise RuntimeError(
                        f"Duplicate generic function overload {self.key} with arguments {f.input_types.values()}"
                    )
                self.user_templates[sig] = f
            else:
                if sig in self.user_overloads:
                    raise RuntimeError(
                        f"Duplicate function overload {self.key} with arguments {f.input_types.values()}"
                    )
                self.user_overloads[sig] = f

    def get_overload(self, arg_types):
        """Return the user overload matching ``arg_types``, instantiating a generic template if needed.

        Concrete overloads are looked up by signature first. Otherwise the first
        generic template whose parameters match ``arg_types`` is copied and
        specialized, cached in ``user_overloads``, and returned. Returns None if
        nothing matches.
        """
        assert not self.is_builtin()

        sig = warp.types.get_signature(arg_types, func_name=self.key)

        f = self.user_overloads.get(sig)
        if f is not None:
            return f

        for f in self.user_templates.values():
            if len(f.input_types) != len(arg_types):
                continue

            # try to match the given types to the function template types
            template_types = list(f.input_types.values())
            args_matched = all(
                warp.types.type_matches_template(arg_type, template_type)
                for arg_type, template_type in zip(arg_types, template_types)
            )

            if args_matched:
                # instantiate this function with the specified argument types;
                # keep the template's code transformers so the specialization
                # is generated from the same transformed source
                arg_names = f.input_types.keys()
                overload_annotations = dict(zip(arg_names, arg_types))

                ovl = shallowcopy(f)
                ovl.adj = warp.codegen.Adjoint(f.func, overload_annotations, transformers=f.adj.transformers)
                ovl.input_types = overload_annotations
                ovl.value_func = None

                self.user_overloads[sig] = ovl

                return ovl

        # failed to find overload
        return None

    def __repr__(self):
        inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in self.input_types.items()])
        return f"<Function {self.key}({inputs_str})>"
def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
    """Try to invoke the C-exported implementation of built-in ``func`` with ``params``.

    Each parameter is packed into the ctypes representation expected by the
    matching entry of ``func.input_types``. If the number of parameters differs
    from the number of arguments, or any parameter cannot be matched to its
    argument type, no call is made and ``(False, None)`` is returned so the
    caller can try the next overload.

    Args:
        func: A non-generic built-in overload whose ``mangled_name`` is set.
        *params: Positional values passed from the Python interpreter.

    Returns:
        ``(True, result)`` on success. Array/structure results are returned as the
        ctypes object, ``float16`` results as a Python float, and other scalars
        as ``ret.value``. Returns ``(False, None)`` when the parameters don't fit
        this overload.

    Raises:
        ValueError: if a sequence parameter mixes element types.
    """
    from collections import deque

    # Overload resolution relies on an exact arity match: without this check,
    # surplus values would be silently ignored and missing ones would raise
    # IndexError instead of letting the caller try the next overload.
    if len(params) != len(func.input_types):
        return (False, None)

    uses_non_warp_array_type = False

    # Retrieve the built-in function from Warp's dll.
    c_func = getattr(warp.context.runtime.core, func.mangled_name)

    # Try gathering the parameters that the function expects and pack them
    # into their corresponding C types.
    c_params = []
    for param, arg_type in zip(params, func.input_types.values()):
        try:
            iter(param)
        except TypeError:
            is_array = False
        else:
            is_array = True

        if is_array:
            if not issubclass(arg_type, ctypes.Array):
                return (False, None)

            # The argument expects a built-in Warp type like a vector or a matrix.

            c_param = None

            if isinstance(param, ctypes.Array):
                # The given parameter is also a built-in Warp type, so we only need
                # to make sure that it matches with the argument.
                if not warp.types.types_equal(type(param), arg_type):
                    return (False, None)

                if isinstance(param, arg_type):
                    c_param = param
                else:
                    # Cast the value to its argument type to make sure that it
                    # can be assigned to the field of the `Param` struct.
                    # This could error otherwise when, for example, the field type
                    # is set to `vec3i` while the value is of type `vector(length=3, dtype=int)`,
                    # even though both types are semantically identical.
                    c_param = arg_type(param)
            else:
                # Flatten the parameter values into a flat 1-D array, visiting
                # nested sequences breadth-first and tracking the deepest leaf.
                arr = []
                ndim = 1
                pending = deque([(0, param)])
                while pending:
                    depth, elem = pending.popleft()
                    try:
                        # If `elem` is a sequence, then it should be possible
                        # to add its elements to the queue for later processing.
                        pending.extend((depth + 1, x) for x in elem)
                    except TypeError:
                        # Since `elem` doesn't seem to be a sequence,
                        # we must have a leaf value that we need to add to our
                        # resulting array.
                        arr.append(elem)
                        ndim = max(depth, ndim)

                # Ensure that if the given parameter value is, say, a 2-D array,
                # then we try to resolve it against a matrix argument rather than
                # a vector.
                if ndim > len(arg_type._shape_):
                    return (False, None)

                elem_count = len(arr)
                if elem_count != arg_type._length_:
                    return (False, None)

                # Retrieve the element type of the sequence while ensuring
                # that it's homogeneous.
                elem_type = type(arr[0])
                for elem in arr[1:]:
                    if type(elem) is not elem_type:
                        raise ValueError("All array elements must share the same type.")

                expected_elem_type = arg_type._wp_scalar_type_
                if not (
                    elem_type is expected_elem_type
                    or (elem_type is float and expected_elem_type is warp.types.float32)
                    or (elem_type is int and expected_elem_type is warp.types.int32)
                    or (
                        issubclass(elem_type, np.number)
                        and warp.types.np_dtype_to_warp_type[np.dtype(elem_type)] is expected_elem_type
                    )
                ):
                    # The parameter value has a type not matching the type defined
                    # for the corresponding argument.
                    return (False, None)

                if elem_type in warp.types.int_types:
                    # Pass the value through the expected integer type
                    # in order to evaluate any integer wrapping.
                    # For example `uint8(-1)` should result in the value `255`.
                    arr = tuple(elem_type._type_(x.value).value for x in arr)
                elif elem_type in warp.types.float_types:
                    # Extract the floating-point values.
                    arr = tuple(x.value for x in arr)

                c_param = arg_type()
                if warp.types.type_is_matrix(arg_type):
                    # fill the matrix row by row from the flattened values
                    rows, cols = arg_type._shape_
                    for row in range(rows):
                        start = row * cols
                        c_param[row] = arr[start : start + cols]
                else:
                    c_param[:] = arr

                uses_non_warp_array_type = True

            c_params.append(ctypes.byref(c_param))

        else:
            if issubclass(arg_type, ctypes.Array):
                return (False, None)

            if not (
                isinstance(param, arg_type)
                or (type(param) is float and arg_type is warp.types.float32)
                or (type(param) is int and arg_type is warp.types.int32)
                or warp.types.np_dtype_to_warp_type.get(getattr(param, "dtype", None)) is arg_type
            ):
                return (False, None)

            if type(param) in warp.types.scalar_types:
                param = param.value

            # try to pack as a scalar type
            if arg_type == warp.types.float16:
                c_params.append(arg_type._type_(warp.types.float_to_half_bits(param)))
            else:
                c_params.append(arg_type._type_(param))

    # returns the corresponding ctype for a scalar or vector warp type
    value_type = func.value_func(None, None, None)
    if value_type == float:
        value_ctype = ctypes.c_float
    elif value_type == int:
        value_ctype = ctypes.c_int32
    elif issubclass(value_type, (ctypes.Array, ctypes.Structure)):
        value_ctype = value_type
    else:
        # scalar type
        value_ctype = value_type._type_

    # construct return value (passed by address)
    ret = value_ctype()
    ret_addr = ctypes.c_void_p(ctypes.addressof(ret))
    c_params.append(ret_addr)

    # Call the built-in function from Warp's dll.
    c_func(*c_params)

    # TODO: uncomment when we have a way to print warning messages only once.
    # if uses_non_warp_array_type:
    #     warp.utils.warn(
    #         "Support for built-in functions called with non-Warp array types, "
    #         "such as lists, tuples, NumPy arrays, and others, will be dropped "
    #         "in the future. Use a Warp type such as `wp.vec`, `wp.mat`, "
    #         "`wp.quat`, or `wp.transform`.",
    #         DeprecationWarning,
    #         stacklevel=3
    #     )

    if issubclass(value_ctype, ctypes.Array) or issubclass(value_ctype, ctypes.Structure):
        # return vector types as ctypes
        return (True, ret)

    if value_type == warp.types.float16:
        return (True, warp.types.half_bits_to_float(ret.value))

    # return scalar types as int/float
    return (True, ret.value)
class KernelHooks:
    """Pair of entry points for a kernel: the forward hook and the backward (adjoint) hook."""

    def __init__(self, forward, backward):
        self.forward, self.backward = forward, backward
# caches source and compiled entry points for a kernel (will be populated after module loads)
class Kernel:
    """A Warp kernel built from a Python function.

    The function is parsed into ``self.adj`` (a ``warp.codegen.Adjoint``) and
    registered with its Warp module. A kernel is generic when any of its
    argument types is generic. Concrete instances of a generic kernel live in
    ``self.overloads``, keyed by type signature.
    """

    def __init__(self, func, key=None, module=None, options=None, code_transformers=None):
        self.func = func

        if module is None:
            self.module = get_module(func.__module__)
        else:
            self.module = module

        if key is None:
            unique_key = self.module.generate_unique_kernel_key(func.__name__)
            self.key = unique_key
        else:
            self.key = key

        self.options = {} if options is None else options

        # a fresh list avoids sharing a mutable default between kernels
        self.adj = warp.codegen.Adjoint(func, transformers=[] if code_transformers is None else code_transformers)

        # check if generic
        self.is_generic = any(warp.types.type_is_generic(arg_type) for arg_type in self.adj.arg_types.values())

        # unique signature (used to differentiate instances of generic kernels during codegen)
        self.sig = ""

        # known overloads for generic kernels, indexed by type signature
        self.overloads = {}

        # argument indices by name
        self.arg_indices = {a.label: i for i, a in enumerate(self.adj.args)}

        if self.module:
            self.module.register_kernel(self)

    def infer_argument_types(self, args):
        """Infer concrete argument types for launching this kernel with ``args``.

        Raises:
            RuntimeError: if the number of arguments does not match the kernel.
        """
        template_types = list(self.adj.arg_types.values())

        if len(args) != len(template_types):
            raise RuntimeError(f"Invalid number of arguments for kernel {self.key}")

        arg_names = list(self.adj.arg_types.keys())

        return warp.types.infer_argument_types(args, template_types, arg_names)

    def add_overload(self, arg_types):
        """Instantiate this kernel for the concrete ``arg_types`` and return the new overload.

        The module is unloaded so the new overload gets compiled on the next load.

        Raises:
            RuntimeError: on a wrong argument count or a duplicate overload.
            TypeError: if an argument type is generic or doesn't match the kernel parameter.
        """
        if len(arg_types) != len(self.adj.arg_types):
            raise RuntimeError(f"Invalid number of arguments for kernel {self.key}")

        arg_names = list(self.adj.arg_types.keys())
        template_types = list(self.adj.arg_types.values())

        # make sure all argument types are concrete and match the kernel parameters
        for arg_name, arg_type, template_type in zip(arg_names, arg_types, template_types):
            if not warp.types.type_matches_template(arg_type, template_type):
                if warp.types.type_is_generic(arg_type):
                    raise TypeError(f"Kernel {self.key} argument '{arg_name}' cannot be generic, got {arg_type}")
                else:
                    raise TypeError(
                        f"Kernel {self.key} argument '{arg_name}' type mismatch: expected {template_type}, got {arg_type}"
                    )

        # get a type signature from the given argument types
        sig = warp.types.get_signature(arg_types, func_name=self.key)
        if sig in self.overloads:
            raise RuntimeError(
                f"Duplicate overload for kernel {self.key}, an overload with the given arguments already exists"
            )

        overload_annotations = dict(zip(arg_names, arg_types))

        # instantiate this kernel with the given argument types; reuse the
        # generic kernel's code transformers so the overload compiles the same
        # transformed source
        ovl = shallowcopy(self)
        ovl.adj = warp.codegen.Adjoint(self.func, overload_annotations, transformers=self.adj.transformers)
        ovl.is_generic = False
        ovl.overloads = {}
        ovl.sig = sig

        self.overloads[sig] = ovl

        self.module.unload()

        return ovl

    def get_overload(self, arg_types):
        """Return the overload for ``arg_types``, creating it on first use."""
        sig = warp.types.get_signature(arg_types, func_name=self.key)

        ovl = self.overloads.get(sig)
        if ovl is not None:
            return ovl
        else:
            return self.add_overload(arg_types)

    def get_mangled_name(self):
        """Return the kernel key, suffixed with the overload signature when one is set."""
        if self.sig:
            return f"{self.key}_{self.sig}"
        else:
            return self.key
# ----------------------
# decorator to register function, @func
def func(f):
    """Decorator registering ``f`` as a Warp user function (``@wp.func``).

    The function is registered in the Warp module of ``f.__module__`` under its
    fully-qualified name. The module's entry for that name is returned; it is the
    top of the list of overloads for that key.
    """
    qualified_name = warp.codegen.make_full_qualified_name(f)
    owner = get_module(f.__module__)

    # The return type is not known yet; it is inferred during Adjoint.build(),
    # so no value_func is supplied here. Construction registers with `owner`.
    Function(func=f, key=qualified_name, namespace="", module=owner, value_func=None)

    return owner.functions[qualified_name]
def func_native(snippet, adj_snippet=None):
    """
    Decorator to register native code snippet, @func_native

    ``snippet`` is the native source for the function and ``adj_snippet``
    (optional) the native source for its adjoint. The decorated Python function
    supplies the signature. The decorator returns the module's entry registered
    under the function's fully-qualified name.
    """

    def snippet_func(f):
        name = warp.codegen.make_full_qualified_name(f)

        m = get_module(f.__module__)
        # constructing the Function registers it with module `m`;
        # cuda snippets do not have a return value_type
        Function(func=f, key=name, namespace="", module=m, native_snippet=snippet, adj_native_snippet=adj_snippet)

        return m.functions[name]

    return snippet_func
def func_grad(forward_fn):
    """
    Decorator to register a custom gradient function for a given forward function.
    The function signature must correspond to one of the function overloads in the following way:
    the first part of the input arguments are the original input variables with the same types as their
    corresponding arguments in the original function, and the second part of the input arguments are the
    adjoint variables of the output variables (if available) of the original function with the same types as the
    output variables. The function must not return anything.

    Raises:
        RuntimeError: if the forward or gradient function has generic arguments,
            or no forward overload matches the gradient function's signature.
    """

    def wrapper(grad_fn):
        generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
        if generic:
            raise RuntimeError(
                f"Cannot define custom grad definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
            )

        reverse_args = {}
        reverse_args.update(forward_fn.input_types)

        # create temporary Adjoint instance to analyze the function signature
        adj = warp.codegen.Adjoint(
            grad_fn, skip_forward_codegen=True, skip_reverse_codegen=False, transformers=forward_fn.adj.transformers
        )

        from warp.types import types_equal

        grad_args = adj.args
        grad_sig = warp.types.get_signature([arg.type for arg in grad_args], func_name=forward_fn.key)

        generic = any(warp.types.type_is_generic(x.type) for x in grad_args)
        if generic:
            raise RuntimeError(
                f"Cannot define custom grad definition for {forward_fn.key} since the provided grad function has generic input arguments."
            )

        def match_function(f):
            # check whether the function overload f matches the signature of the provided gradient function
            if not hasattr(f.adj, "return_var"):
                f.adj.build(None)
            expected_args = list(f.input_types.items())
            if f.adj.return_var is not None:
                expected_args += [(f"adj_ret_{var.label}", var.type) for var in f.adj.return_var]
            if len(grad_args) != len(expected_args):
                return False
            if any(not types_equal(a.type, exp_type) for a, (_, exp_type) in zip(grad_args, expected_args)):
                return False
            return True

        def add_custom_grad(f: Function):
            # register custom gradient function
            f.custom_grad_func = Function(
                grad_fn,
                key=f.key,
                namespace=f.namespace,
                input_types=reverse_args,
                value_func=None,
                module=f.module,
                template_func=f.template_func,
                skip_forward_codegen=True,
                custom_reverse_mode=True,
                custom_reverse_num_input_args=len(f.input_types),
                skip_adding_overload=False,
                code_transformers=f.adj.transformers,
            )
            # the custom gradient replaces the automatically generated reverse pass
            f.adj.skip_reverse_codegen = True

        if hasattr(forward_fn, "user_overloads") and len(forward_fn.user_overloads):
            # find matching overload for which this grad function is defined
            for sig, f in forward_fn.user_overloads.items():
                if not grad_sig.startswith(sig):
                    continue
                if match_function(f):
                    add_custom_grad(f)
                    return
            raise RuntimeError(
                f"No function overload found for gradient function {grad_fn.__qualname__} for function {forward_fn.key}"
            )
        else:
            # resolve return variables
            forward_fn.adj.build(None)

            # check if the signature matches this function
            if match_function(forward_fn):
                add_custom_grad(forward_fn)
            else:
                expected_args = list(forward_fn.input_types.items())
                if forward_fn.adj.return_var is not None:
                    expected_args += [(f"adj_ret_{var.label}", var.type) for var in forward_fn.adj.return_var]
                # type_repr handles annotations without a __name__ (e.g. array
                # instances), which would otherwise mask this error with an AttributeError
                expected_str = ", ".join(f"{name}: {warp.types.type_repr(t)}" for name, t in expected_args)
                raise RuntimeError(
                    f"Gradient function {grad_fn.__qualname__} for function {forward_fn.key} has an incorrect signature. The arguments must match the "
                    "forward function arguments plus the adjoint variables corresponding to the return variables:"
                    f"\n{expected_str}"
                )

    return wrapper
def func_replay(forward_fn):
    """
    Decorator to register a custom replay function for a given forward function.
    The replay function is the function version that is called in the forward phase of the backward pass (replay mode) and corresponds to the forward function by default.
    The provided function has to match the signature of one of the original forward function overloads.

    Raises:
        RuntimeError: if either function has generic arguments, or no forward
            overload matches the replay function's argument types.
    """

    def wrapper(replay_fn):
        generic = any(warp.types.type_is_generic(x) for x in forward_fn.input_types.values())
        if generic:
            raise RuntimeError(
                f"Cannot define custom replay definition for {forward_fn.key} since functions with generic input arguments are not yet supported."
            )

        args = get_function_args(replay_fn)
        # a return annotation is not an argument and must not take part in the
        # overload lookup, which matches on argument types only
        arg_items = [(name, t) for name, t in args.items() if name != "return"]
        arg_types = [t for _, t in arg_items]
        generic = any(warp.types.type_is_generic(x) for x in arg_types)
        if generic:
            raise RuntimeError(
                f"Cannot define custom replay definition for {forward_fn.key} since the provided replay function has generic input arguments."
            )

        f = forward_fn.get_overload(arg_types)
        if f is None:
            # type_repr handles annotations without a __name__ (e.g. array instances)
            inputs_str = ", ".join([f"{k}: {warp.types.type_repr(v)}" for k, v in arg_items])
            raise RuntimeError(
                f"Could not find forward definition of function {forward_fn.key} that matches custom replay definition with arguments:\n{inputs_str}"
            )
        f.custom_replay_func = Function(
            replay_fn,
            key=f"replay_{f.key}",
            namespace=f.namespace,
            input_types=f.input_types,
            value_func=f.value_func,
            module=f.module,
            template_func=f.template_func,
            skip_reverse_codegen=True,
            skip_adding_overload=True,
            code_transformers=f.adj.transformers,
        )

    return wrapper
# decorator to register kernel, used as @kernel or @kernel(enable_backward=...);
# the kernel key is the function's fully-qualified name
def kernel(f=None, *, enable_backward=None):
    """Decorator registering a Warp kernel.

    It can be used bare (``@wp.kernel``) or with arguments
    (``@wp.kernel(enable_backward=False)``). When ``enable_backward`` is given,
    it is stored in the kernel's options.
    """

    def wrapper(f, *args, **kwargs):
        if enable_backward is None:
            kernel_options = {}
        else:
            kernel_options = {"enable_backward": enable_backward}

        owner = get_module(f.__module__)
        return Kernel(
            func=f,
            key=warp.codegen.make_full_qualified_name(f),
            module=owner,
            options=kernel_options,
        )

    # with decorator arguments, `f` is None and the wrapper is applied later
    return wrapper if f is None else wrapper(f)
# decorator to register struct, @struct
def struct(c):
    """Decorator turning class ``c`` into a Warp struct registered with its module."""
    owner = get_module(c.__module__)
    qualified_name = warp.codegen.make_full_qualified_name(c)
    return warp.codegen.Struct(cls=c, key=qualified_name, module=owner)
# overload a kernel with the given argument types
def overload(kernel, arg_types=None):
    """Add a concrete overload to a generic kernel.

    Two call forms are supported:

    * ``wp.overload(kernel, arg_types)``. ``arg_types`` may be:

      - a list of types (positional);
      - a dict mapping argument names to types, where unlisted arguments keep
        the kernel's own annotation;
      - ``None``.

      Returns the new overload.
    * ``@wp.overload`` applied to a stub function that has the same name as a
      generic kernel in the same module. The stub's annotations give the
      concrete types, and its body may contain only ``pass``, ``...``, or
      string expressions. Returns the original generic kernel.

    Raises:
        RuntimeError: if the kernel is not generic or cannot be found, the stub
            has a body or incomplete annotations, a dict names an unknown
            argument, or ``kernel`` is neither a Kernel nor a function.
        TypeError: if ``arg_types`` is neither a list, a dict, nor None.
    """
    if isinstance(kernel, Kernel):
        # handle cases where user calls us directly, e.g. wp.overload(kernel, [args...])

        if not kernel.is_generic:
            raise RuntimeError(f"Only generic kernels can be overloaded. Kernel {kernel.key} is not generic")

        if isinstance(arg_types, list):
            arg_list = arg_types
        elif isinstance(arg_types, dict):
            # substitute named args
            arg_list = [a.type for a in kernel.adj.args]
            for arg_name, arg_type in arg_types.items():
                idx = kernel.arg_indices.get(arg_name)
                if idx is None:
                    raise RuntimeError(f"Invalid argument name '{arg_name}' in overload of kernel {kernel.key}")
                arg_list[idx] = arg_type
        elif arg_types is None:
            arg_list = []
        else:
            raise TypeError("Kernel overload types must be given in a list or dict")

        # return new kernel overload
        return kernel.add_overload(arg_list)

    elif isinstance(kernel, types.FunctionType):
        # handle cases where user calls us as a function decorator (@wp.overload)
        import textwrap

        # ensure this function name corresponds to a kernel
        fn = kernel
        module = get_module(fn.__module__)
        # kernel() registers kernels under their fully-qualified name; fall back
        # to the plain function name for kernels registered under it
        kernel = module.kernels.get(warp.codegen.make_full_qualified_name(fn))
        if kernel is None:
            kernel = module.kernels.get(fn.__name__)
        if kernel is None:
            raise RuntimeError(f"Failed to find a kernel named '{fn.__name__}' in module {fn.__module__}")

        if not kernel.is_generic:
            raise RuntimeError(f"Only generic kernels can be overloaded. Kernel {kernel.key} is not generic")

        # ensure the function is defined without a body, only ellipsis (...), pass, or a string expression
        # TODO: should we allow defining a new body for kernel overloads?
        # dedent so stubs defined inside functions or classes can be parsed
        source = textwrap.dedent(inspect.getsource(fn))
        tree = ast.parse(source)
        if not tree.body or not isinstance(tree.body[0], ast.FunctionDef):
            raise RuntimeError(f"Failed to parse the definition of kernel overload {fn.__name__}")
        func_body = tree.body[0].body
        for node in func_body:
            if isinstance(node, ast.Pass):
                continue
            # docstrings/strings and `...` parse as ast.Constant; the ast.Str and
            # ast.Ellipsis aliases are deprecated since Python 3.8 and removed in 3.14
            if (
                isinstance(node, ast.Expr)
                and isinstance(node.value, ast.Constant)
                and (isinstance(node.value.value, str) or node.value.value is Ellipsis)
            ):
                continue
            raise RuntimeError(
                "Illegal statement in kernel overload definition. Only pass, ellipsis (...), comments, or docstrings are allowed"
            )

        # ensure all arguments are annotated; the return annotation does not count
        argspec = inspect.getfullargspec(fn)
        if any(arg not in argspec.annotations for arg in argspec.args):
            raise RuntimeError(f"Incomplete argument annotations on kernel overload {fn.__name__}")

        # get type annotation list in parameter order
        arg_list = [argspec.annotations[arg_name] for arg_name in argspec.args]

        # add new overload, but we must return the original kernel from @wp.overload decorator!
        kernel.add_overload(arg_list)
        return kernel

    else:
        raise RuntimeError("wp.overload() called with invalid argument!")
builtin_functions = {}


def add_builtin(
    key,
    input_types=None,
    value_type=None,
    value_func=None,
    template_func=None,
    doc="",
    namespace="wp::",
    variadic=False,
    initializer_list_func=None,
    export=True,
    group="Other",
    hidden=False,
    skip_replay=False,
    missing_grad=False,
    native_func=None,
    defaults=None,
    require_original_output_arg=False,
):
    """Register a built-in (native) function ``key`` with the given argument types.

    The new ``Function`` becomes the overload-list head in
    ``builtin_functions[key]``, or is added to the existing head.

    If the signature is generic and ``export`` is set, hidden concrete overloads
    are also registered for every combination of hard-coded Warp types that
    share a scalar type and for which ``value_func`` succeeds. This lets the
    builtin be called from Python with concrete Warp types.

    With ``export`` set, the function is also stored as an attribute of the
    ``warp`` module.

    Args:
        key: Python-visible name of the builtin.
        input_types: Mapping of argument name to type; defaults to no arguments.
        value_type: Fixed return type, used when ``value_func`` is not given.
        value_func: Callable ``(args, kwds, templates)`` returning the result type.

    Raises:
        RuntimeError: if exporting would overwrite an existing ``warp`` attribute
            that is not an ``_overload_dummy`` stub.
    """
    import itertools

    # avoid a shared mutable default argument
    if input_types is None:
        input_types = {}

    # wrap simple single-type functions with a value_func()
    if value_func is None:

        def value_func(args, kwds, templates):
            return value_type

    if initializer_list_func is None:

        def initializer_list_func(args, templates):
            return False

    if defaults is None:
        defaults = {}

    # Add specialized versions of this builtin if it's generic by matching arguments against
    # hard coded types. We do this so you can use hard coded warp types outside kernels:
    generic = any(warp.types.type_is_generic(x) for x in input_types.values())
    if generic and export:
        # get a list of existing generic vector types (includes matrices and stuff)
        # so we can match arguments against them:
        generic_vtypes = [x for x in warp.types.vector_types if hasattr(x, "_wp_generic_type_str_")]

        # deduplicate identical types:
        def typekey(t):
            return f"{t._wp_generic_type_str_}_{t._wp_type_params_}"

        typedict = {typekey(t): t for t in generic_vtypes}
        generic_vtypes = [typedict[k] for k in sorted(typedict.keys())]

        # collect the parent type names of all the generic arguments:
        def generic_names(type_list):
            for t in type_list:
                if hasattr(t, "_wp_generic_type_str_"):
                    yield t._wp_generic_type_str_
                elif warp.types.type_is_generic_scalar(t):
                    yield t.__name__

        genericset = set(generic_names(input_types.values()))

        # for each of those type names, get a list of all hard coded types derived
        # from them:
        def derived(name):
            if name == "Float":
                return warp.types.float_types
            elif name == "Scalar":
                return warp.types.scalar_types
            elif name == "Int":
                return warp.types.int_types
            return [x for x in generic_vtypes if x._wp_generic_type_str_ == name]

        gtypes = {k: derived(k) for k in genericset}

        # find the scalar data types supported by all the arguments by intersecting
        # sets:
        def scalar_type(t):
            if t in warp.types.scalar_types:
                return t
            return [p for p in t._wp_type_params_ if p in warp.types.scalar_types][0]

        scalartypes = [{scalar_type(x) for x in gtypes[k]} for k in gtypes.keys()]
        if scalartypes:
            scalartypes = scalartypes.pop().intersection(*scalartypes)

        scalartypes = list(scalartypes)
        scalartypes.sort(key=str)

        # generate function calls for each of these scalar types:
        for stype in scalartypes:
            # find concrete types for this scalar type (eg if the scalar type is float32
            # this dict will look something like this:
            # {"vec":[wp.vec2,wp.vec3,wp.vec4], "mat":[wp.mat22,wp.mat33,wp.mat44]})
            consistenttypes = {k: [x for x in v if scalar_type(x) == stype] for k, v in gtypes.items()}

            def typelist(param):
                if warp.types.type_is_generic_scalar(param):
                    return [stype]
                if hasattr(param, "_wp_generic_type_str_"):
                    candidates = consistenttypes[param._wp_generic_type_str_]
                    return [x for x in candidates if warp.types.types_equal(param, x, match_generic=True)]
                return [param]

            # gotta try generating function calls for all combinations of these argument types
            # now.
            typelists = [typelist(param) for param in input_types.values()]
            for argtypes in itertools.product(*typelists):
                # Some of these argument lists won't work, eg if the function is mul(), we won't be
                # able to do a matrix vector multiplication for a mat22 and a vec3, so we call value_func
                # on the generated argument list and skip generation if it fails.
                # This also gives us the return type, which we keep for later:
                try:
                    return_type = value_func(argtypes, {}, [])
                except Exception:
                    continue

                # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
                # in the list of hard coded types so it knows it's returning one of them:
                if hasattr(return_type, "_wp_generic_type_str_"):
                    return_type_match = [
                        x
                        for x in generic_vtypes
                        if x._wp_generic_type_str_ == return_type._wp_generic_type_str_
                        and x._wp_type_params_ == return_type._wp_type_params_
                    ]
                    if not return_type_match:
                        continue
                    return_type = return_type_match[0]

                # finally we can generate a function call for these concrete types:
                add_builtin(
                    key,
                    input_types=dict(zip(input_types.keys(), argtypes)),
                    value_type=return_type,
                    doc=doc,
                    namespace=namespace,
                    variadic=variadic,
                    initializer_list_func=initializer_list_func,
                    export=export,
                    group=group,
                    hidden=True,
                    skip_replay=skip_replay,
                    missing_grad=missing_grad,
                    require_original_output_arg=require_original_output_arg,
                )

    func = Function(
        func=None,
        key=key,
        namespace=namespace,
        input_types=input_types,
        value_func=value_func,
        template_func=template_func,
        variadic=variadic,
        initializer_list_func=initializer_list_func,
        export=export,
        doc=doc,
        group=group,
        hidden=hidden,
        skip_replay=skip_replay,
        missing_grad=missing_grad,
        generic=generic,
        native_func=native_func,
        defaults=defaults,
        require_original_output_arg=require_original_output_arg,
    )

    if key in builtin_functions:
        builtin_functions[key].add_overload(func)
    else:
        builtin_functions[key] = func

        # export means the function will be added to the `warp` module namespace
        # so that users can call it directly from the Python interpreter
        if export:
            if hasattr(warp, key):
                # check that we haven't already created something at this location
                # if it's just an overload stub for auto-complete then overwrite it
                if getattr(warp, key).__name__ != "_overload_dummy":
                    raise RuntimeError(
                        f"Trying to register builtin function '{key}' that would overwrite existing object."
                    )

            setattr(warp, key, func)
# global dictionary of modules
user_modules = {}


def get_module(name):
    """Return the Warp module for the Python module ``name``, creating it on first use.

    If a Warp module already exists but the Python module's loader object has
    changed, the source is assumed to have changed. In that case the old Warp
    module and, transitively, all its dependents are unloaded, and its kernels,
    functions, constants and structs are cleared so they are re-registered.
    """
    # some modules might be manually imported using `importlib` without being
    # registered into `sys.modules`
    py_module = sys.modules.get(name, None)
    current_loader = py_module.__loader__ if py_module is not None else None

    if name not in user_modules:
        # Warp module doesn't exist yet, so create a new one
        user_modules[name] = warp.context.Module(name, current_loader)
        return user_modules[name]

    existing = user_modules[name]
    if existing.loader is not current_loader:
        # Unload the module and, depth-first, every module depending on it so they
        # are re-hashed and reloaded on next launch; `seen` guards against cycles.
        seen = set()

        def _unload_with_dependents(mod):
            mod.unload()
            seen.add(mod)
            for dependent in mod.dependents:
                if dependent not in seen:
                    _unload_with_dependents(dependent)

        _unload_with_dependents(existing)

        # clear out old kernels, funcs, struct definitions
        existing.kernels = {}
        existing.functions = {}
        existing.constants = []
        existing.structs = {}
        existing.loader = current_loader

    return user_modules[name]
class ModuleBuilder:
    """Builds the contents of a Warp ``Module`` and generates its native source.

    Construction builds every user function overload (plus its custom replay
    function, if any) and every kernel of ``module``. For generic kernels, every
    registered overload is built. ``codegen`` then emits source for the collected
    structs, functions and kernels.

    Args:
        module: the ``Module`` whose functions and kernels are built.
        options: module build options, forwarded to ``warp.codegen``.
    """

    def __init__(self, module, options):
        # dicts keyed by object with ``None`` values act as insertion-ordered sets,
        # so declarations are emitted in the order they were built
        self.functions = {}
        self.structs = {}
        self.options = options
        self.module = module

        # build all functions declared in the module
        for func in module.functions.values():
            for f in func.user_overloads.values():
                self.build_function(f)
                if f.custom_replay_func is not None:
                    self.build_function(f.custom_replay_func)

        # build all kernel entry points
        for kernel in module.kernels.values():
            if not kernel.is_generic:
                self.build_kernel(kernel)
            else:
                for k in kernel.overloads.values():
                    self.build_kernel(k)

    def build_struct_recursive(self, struct: warp.codegen.Struct):
        """Register ``struct`` and every struct it (transitively) contains, dependencies first.

        Nested structs are discovered through struct-typed members and through
        arrays whose dtype is a struct.
        """
        structs = []

        stack = [struct]
        while stack:
            s = stack.pop()

            structs.append(s)

            for var in s.vars.values():
                if isinstance(var.type, warp.codegen.Struct):
                    stack.append(var.type)
                elif isinstance(var.type, warp.types.array) and isinstance(var.type.dtype, warp.codegen.Struct):
                    stack.append(var.type.dtype)

        # Build them in reverse to generate a correct dependency order: a member struct
        # is always appended after the struct containing it, so in reversed order (with
        # build_struct ignoring repeats) dependencies come before their users.
        for s in reversed(structs):
            self.build_struct(s)

    def build_struct(self, struct):
        """Record ``struct`` for code generation (repeat registrations keep the first position)."""
        self.structs[struct] = None

    def build_kernel(self, kernel):
        """Build ``kernel``'s adjoint and reject kernels that return a value.

        Raises:
            TypeError: if the kernel has a return statement producing a non-void value.
        """
        kernel.adj.build(self)

        return_var = kernel.adj.return_var
        if return_var is not None:
            # `adj.return_var` is a sequence of variables (see build_function, which
            # indexes it); accept a single variable too for robustness. Calling
            # `.ctype()` on the sequence itself would raise AttributeError instead of
            # the intended TypeError.
            return_vars = return_var if isinstance(return_var, (list, tuple)) else (return_var,)
            if any(v.ctype() != "void" for v in return_vars):
                raise TypeError(f"Error, kernels can't have return values, got: {kernel.adj.return_var}")

    def build_function(self, func):
        """Build ``func`` once and record it for code generation.

        If ``func`` has no ``value_func``, one is installed that reports the return
        type(s) inferred from the function's return statements. It reports ``None``
        when there is no return value, a single type for one value, and a list of
        types otherwise.
        """
        if func in self.functions:
            return
        else:
            func.adj.build(self)

            # complete the function return type after we have analyzed it (inferred from return statement in ast)
            if not func.value_func:

                def wrap(adj):
                    def value_type(arg_types, kwds, templates):
                        if adj.return_var is None or len(adj.return_var) == 0:
                            return None
                        if len(adj.return_var) == 1:
                            return adj.return_var[0].type
                        else:
                            return [v.type for v in adj.return_var]

                    return value_type

                func.value_func = wrap(func.adj)

            # use dict to preserve import order
            self.functions[func] = None

    def codegen(self, device):
        """Return the complete generated source for ``device``.

        ``"cpu"`` selects the CPU module header; any other value selects the CUDA
        header. The header is followed by structs, functions, then kernel entry points.
        """
        parts = []

        # code-gen structs
        for struct in self.structs.keys():
            parts.append(warp.codegen.codegen_struct(struct))

        # code-gen all imported functions
        for func in self.functions.keys():
            if func.native_snippet is None:
                parts.append(
                    warp.codegen.codegen_func(
                        func.adj, c_func_name=func.native_func, device=device, options=self.options
                    )
                )
            else:
                parts.append(
                    warp.codegen.codegen_snippet(
                        func.adj, name=func.key, snippet=func.native_snippet, adj_snippet=func.adj_native_snippet
                    )
                )

        for kernel in self.module.kernels.values():
            # each kernel (or each overload of a generic kernel) gets an entry point in the module
            entry_points = kernel.overloads.values() if kernel.is_generic else (kernel,)
            for k in entry_points:
                parts.append(warp.codegen.codegen_kernel(k, device=device, options=self.options))
                parts.append(warp.codegen.codegen_module(k, device=device))

        # add headers
        if device == "cpu":
            header = warp.codegen.cpu_module_header
        else:
            header = warp.codegen.cuda_module_header

        return header + "".join(parts)
# -----------------------------------------------------
# stores all functions and kernels for a Python module
# creates a hash of the function to use for checking
# build cache
class Module:
    """Holds the kernels, functions and structs registered from one Python module.

    A ``Module`` does the following:

    - tracks which other modules it references (and which reference it);
    - hashes its contents to validate the on-disk kernel cache;
    - builds and loads the generated code for each device;
    - caches kernel entry points per device.

    Args:
        name: name of the Python module; also used to derive build artifact names.
        loader: the Python module's ``__loader__`` (or ``None``). ``get_module``
            compares it by identity to detect that the Python module was reloaded.
    """

    def __init__(self, name, loader):
        self.name = name
        self.loader = loader

        self.kernels = {}
        self.functions = {}
        self.constants = []
        self.structs = {}

        self.cpu_module = None
        self.cuda_modules = {}  # module lookup by CUDA context

        self.cpu_build_failed = False
        self.cuda_build_failed = False

        self.options = {
            "max_unroll": 16,
            "enable_backward": warp.config.enable_backward,
            "fast_math": False,
            "cuda_output": None,  # supported values: "ptx", "cubin", or None (automatic)
            "mode": warp.config.mode,
        }

        # kernel hook lookup per device
        # hooks are stored with the module so they can be easily cleared when the module is reloaded.
        # -> See ``Module.get_kernel_hooks()``
        self.kernel_hooks = {}

        # Module dependencies are determined by scanning each function
        # and kernel for references to external functions and structs.
        #
        # When a referenced module is modified, all of its dependents need to be reloaded
        # on the next launch. To detect this, a module's hash recursively includes
        # all of its references.
        # -> See ``Module.hash_module()``
        #
        # The dependency mechanism works for both static and dynamic (runtime) modifications.
        # When a module is reloaded at runtime, we recursively unload all of its
        # dependents, so that they will be re-hashed and reloaded on the next launch.
        # -> See ``get_module()``

        self.references = set()  # modules whose content we depend on
        self.dependents = set()  # modules that depend on our content

        # Since module hashing is recursive, we improve performance by caching the hash of the
        # module contents (kernel source, function source, and struct source).
        # After all kernels, functions, and structs are added to the module (usually at import time),
        # the content hash doesn't change.
        # -> See ``Module.hash_module_recursive()``

        self.content_hash = None

        # number of times module auto-generates kernel key for user
        # used to ensure unique kernel keys
        self.count = 0

    def register_struct(self, struct):
        """Add ``struct`` under ``struct.key`` and unload so the module is rebuilt on next launch."""
        self.structs[struct.key] = struct

        # for a reload of module on next launch
        self.unload()

    def register_kernel(self, kernel):
        """Add ``kernel`` under ``kernel.key``, record its module references, and unload for rebuild."""
        self.kernels[kernel.key] = kernel

        self.find_references(kernel.adj)

        # for a reload of module on next launch
        self.unload()

    def register_function(self, func, skip_adding_overload=False):
        """Register ``func`` under ``func.key``.

        If no function with that key exists, ``func`` is stored directly. If one
        exists with the same signature, it is replaced (as a Python redefinition
        would). Otherwise ``func`` is added as an overload of the existing
        function, unless ``skip_adding_overload`` is set. In every case the
        module's references are updated and the module is unloaded for rebuild.
        """
        if func.key not in self.functions:
            self.functions[func.key] = func
        else:
            # Check whether the new function's signature match any that has
            # already been registered. If so, then we simply override it, as
            # Python would do it, otherwise we register it as a new overload.
            func_existing = self.functions[func.key]
            sig = warp.types.get_signature(
                func.input_types.values(),
                func_name=func.key,
                arg_names=list(func.input_types.keys()),
            )
            sig_existing = warp.types.get_signature(
                func_existing.input_types.values(),
                func_name=func_existing.key,
                arg_names=list(func_existing.input_types.keys()),
            )
            if sig == sig_existing:
                self.functions[func.key] = func
            elif not skip_adding_overload:
                func_existing.add_overload(func)

        self.find_references(func.adj)

        # for a reload of module on next launch
        self.unload()

    def generate_unique_kernel_key(self, key):
        """Return ``f"{key}_{n}"`` where ``n`` is a per-module counter incremented on each call."""
        unique_key = f"{key}_{self.count}"
        self.count += 1
        return unique_key

    # collect all referenced functions / structs
    # given the AST of a function or kernel
    def find_references(self, adj):
        """Record modules referenced by ``adj`` (called user functions, struct-typed arguments).

        Each referenced module (other than this one) is added to ``self.references``,
        and this module is added to that module's ``dependents``.
        """

        def add_ref(ref):
            if ref is not self:
                self.references.add(ref)
                ref.dependents.add(self)

        # scan for function calls
        for node in ast.walk(adj.tree):
            if isinstance(node, ast.Call):
                try:
                    # try to resolve the function
                    func, _ = adj.resolve_static_expression(node.func, eval_types=False)

                    # if this is a user-defined function, add a module reference
                    if isinstance(func, warp.context.Function) and func.module is not None:
                        add_ref(func.module)

                except Exception:
                    # Lookups may fail for builtins, but that's ok.
                    # Lookups may also fail for functions in this module that haven't been imported yet,
                    # and that's ok too (not an external reference).
                    pass

        # scan for structs
        for arg in adj.args:
            if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
                add_ref(arg.type.module)

    def hash_module(self):
        """Return a SHA-256 digest identifying everything that affects this module's build.

        The digest combines the following:

        - a cached content hash, covering struct field annotations, function sources
          (including all user overloads, custom grad/replay functions and native
          snippets) with their argument types, kernel sources with their argument
          types, and generic-kernel overload signatures;
        - the module options;
        - the global ``verify_fp`` and ``mode`` config;
        - the global constant hash;
        - recursively, the hashes of referenced modules.
        """

        def get_annotations(obj: Any) -> Mapping[str, Any]:
            """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
            # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
            if isinstance(obj, type):
                return obj.__dict__.get("__annotations__", {})

            return getattr(obj, "__annotations__", {})

        def get_type_name(type_hint):
            if isinstance(type_hint, warp.codegen.Struct):
                return get_type_name(type_hint.cls)
            return type_hint

        def hash_function(ch, func):
            # Python source of the function and of any custom adjoint / replay functions
            ch.update(bytes(func.adj.source, "utf-8"))

            if func.custom_grad_func:
                ch.update(bytes(func.custom_grad_func.adj.source, "utf-8"))
            if func.custom_replay_func:
                ch.update(bytes(func.custom_replay_func.adj.source, "utf-8"))

            # native snippets are emitted by codegen in place of the Python body, so a
            # snippet edit must change the hash even when the Python stub is unchanged
            if func.native_snippet:
                ch.update(bytes(func.native_snippet, "utf-8"))
            if func.adj_native_snippet:
                ch.update(bytes(func.adj_native_snippet, "utf-8"))

            # cache func arg types
            for arg, arg_type in func.adj.arg_types.items():
                s = f"{arg}: {get_type_name(arg_type)}"
                ch.update(bytes(s, "utf-8"))

        def hash_recursive(module, visited):
            # Hash this module, including all referenced modules recursively.
            # The visited set tracks modules already visited to avoid circular references.

            # check if we need to update the content hash
            if not module.content_hash:
                # recompute content hash
                ch = hashlib.sha256()

                # struct source
                for struct in module.structs.values():
                    s = ",".join(
                        "{}: {}".format(name, get_type_name(type_hint))
                        for name, type_hint in get_annotations(struct.cls).items()
                    )
                    ch.update(bytes(s, "utf-8"))

                # functions source
                for func in module.functions.values():
                    hash_function(ch, func)

                    # every user overload is built into the module, so each must
                    # contribute to the hash, not only the primary entry
                    for overload in func.user_overloads.values():
                        if overload is not func:
                            hash_function(ch, overload)

                # kernel source
                for kernel in module.kernels.values():
                    ch.update(bytes(kernel.adj.source, "utf-8"))
                    # cache kernel arg types
                    for arg, arg_type in kernel.adj.arg_types.items():
                        s = f"{arg}: {get_type_name(arg_type)}"
                        ch.update(bytes(s, "utf-8"))
                    # for generic kernels the Python source is always the same,
                    # but we hash the type signatures of all the overloads
                    if kernel.is_generic:
                        for sig in sorted(kernel.overloads.keys()):
                            ch.update(bytes(sig, "utf-8"))

                module.content_hash = ch.digest()

            h = hashlib.sha256()

            # content hash
            h.update(module.content_hash)

            # configuration parameters
            for k in sorted(module.options.keys()):
                s = f"{k}={module.options[k]}"
                h.update(bytes(s, "utf-8"))

            # ensure to trigger recompilation if flags affecting kernel compilation are changed
            if warp.config.verify_fp:
                h.update(bytes("verify_fp", "utf-8"))

            h.update(bytes(warp.config.mode, "utf-8"))

            # compile-time constants (global)
            if warp.types._constant_hash:
                h.update(warp.types._constant_hash.digest())

            # recurse on references
            visited.add(module)

            sorted_deps = sorted(module.references, key=lambda m: m.name)
            for dep in sorted_deps:
                if dep not in visited:
                    dep_hash = hash_recursive(dep, visited)
                    h.update(dep_hash)

            return h.digest()

        return hash_recursive(self, visited=set())

    def load(self, device):
        """Load this module for ``device``, building it unless a matching cached build exists.

        Returns:
            True if the module is loaded for ``device``. False if an earlier build for
            this device type failed and the module has not been unloaded since.

        Raises:
            RuntimeError: if no CPU build chain (CPU) or no CUDA support (CUDA) is available.
            Exception: any codegen, compile or load error, re-raised after the
                corresponding ``*_build_failed`` flag is set.
        """
        from warp.utils import ScopedTimer

        device = get_device(device)

        if device.is_cpu:
            # check if already loaded
            if self.cpu_module:
                return True
            # avoid repeated build attempts
            if self.cpu_build_failed:
                return False
            if not warp.is_cpu_available():
                raise RuntimeError("Failed to build CPU module because no CPU buildchain was found")
        else:
            # check if already loaded
            if device.context in self.cuda_modules:
                return True
            # avoid repeated build attempts
            if self.cuda_build_failed:
                return False
            if not warp.is_cuda_available():
                raise RuntimeError("Failed to build CUDA module because CUDA is not available")

        with ScopedTimer(f"Module {self.name} load on device '{device}'", active=not warp.config.quiet):
            build_path = warp.build.kernel_bin_dir
            gen_path = warp.build.kernel_gen_dir

            # exist_ok avoids a race between an existence check and creation
            os.makedirs(build_path, exist_ok=True)
            os.makedirs(gen_path, exist_ok=True)

            module_name = "wp_" + self.name
            module_path = os.path.join(build_path, module_name)
            module_hash = self.hash_module()

            builder = ModuleBuilder(self, self.options)

            if device.is_cpu:
                obj_path = module_path + ".o"
                cpu_hash_path = module_path + ".cpu.hash"

                # check cache
                if warp.config.cache_kernels and os.path.isfile(cpu_hash_path) and os.path.isfile(obj_path):
                    with open(cpu_hash_path, "rb") as f:
                        cache_hash = f.read()

                    if cache_hash == module_hash:
                        runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
                        self.cpu_module = module_name
                        return True

                # build
                try:
                    cpp_path = os.path.join(gen_path, module_name + ".cpp")

                    # write cpp sources
                    cpp_source = builder.codegen("cpu")
                    with open(cpp_path, "w") as cpp_file:
                        cpp_file.write(cpp_source)

                    # build object code
                    with ScopedTimer("Compile x86", active=warp.config.verbose):
                        warp.build.build_cpu(
                            obj_path,
                            cpp_path,
                            mode=self.options["mode"],
                            fast_math=self.options["fast_math"],
                            verify_fp=warp.config.verify_fp,
                        )

                    # update cpu hash
                    with open(cpu_hash_path, "wb") as f:
                        f.write(module_hash)

                    # load the object code
                    runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
                    self.cpu_module = module_name

                except Exception:
                    self.cpu_build_failed = True
                    raise

            elif device.is_cuda:
                # determine whether to use PTX or CUBIN
                if device.is_cubin_supported:
                    # get user preference specified either per module or globally
                    preferred_cuda_output = self.options.get("cuda_output") or warp.config.cuda_output
                    if preferred_cuda_output is not None:
                        use_ptx = preferred_cuda_output == "ptx"
                    else:
                        # determine automatically: older drivers may not be able to handle PTX generated using newer
                        # CUDA Toolkits, in which case we fall back on generating CUBIN modules
                        use_ptx = runtime.driver_version >= runtime.toolkit_version
                else:
                    # CUBIN not an option, must use PTX (e.g. CUDA Toolkit too old)
                    use_ptx = True

                if use_ptx:
                    output_arch = min(device.arch, warp.config.ptx_target_arch)
                    output_path = module_path + f".sm{output_arch}.ptx"
                else:
                    output_arch = device.arch
                    output_path = module_path + f".sm{output_arch}.cubin"

                cuda_hash_path = module_path + f".sm{output_arch}.hash"

                # check cache
                if warp.config.cache_kernels and os.path.isfile(cuda_hash_path) and os.path.isfile(output_path):
                    with open(cuda_hash_path, "rb") as f:
                        cache_hash = f.read()

                    if cache_hash == module_hash:
                        cuda_module = warp.build.load_cuda(output_path, device)
                        if cuda_module is not None:
                            self.cuda_modules[device.context] = cuda_module
                            return True

                # build
                try:
                    cu_path = os.path.join(gen_path, module_name + ".cu")

                    # write cuda sources
                    cu_source = builder.codegen("cuda")
                    with open(cu_path, "w") as cu_file:
                        cu_file.write(cu_source)

                    # generate PTX or CUBIN
                    with ScopedTimer("Compile CUDA", active=warp.config.verbose):
                        warp.build.build_cuda(
                            cu_path,
                            output_arch,
                            output_path,
                            config=self.options["mode"],
                            fast_math=self.options["fast_math"],
                            verify_fp=warp.config.verify_fp,
                        )

                    # update cuda hash
                    with open(cuda_hash_path, "wb") as f:
                        f.write(module_hash)

                    # load the module
                    cuda_module = warp.build.load_cuda(output_path, device)
                    if cuda_module is not None:
                        self.cuda_modules[device.context] = cuda_module
                    else:
                        raise Exception("Failed to load CUDA module")

                except Exception:
                    self.cuda_build_failed = True
                    raise

            return True

    def unload(self):
        """Release all loaded CPU/CUDA code and reset cached per-load state.

        Kernel hooks and the cached content hash are cleared, so the module is
        re-hashed (and rebuilt if needed) on the next ``load``. Earlier build
        failures are also forgotten. The ``register_*`` methods call ``unload`` when
        the module's contents change, and modified code deserves a fresh build
        attempt instead of ``load`` returning ``False`` without retrying.
        """
        if self.cpu_module:
            runtime.llvm.unload_obj(self.cpu_module.encode("utf-8"))
            self.cpu_module = None

        # need to unload the CUDA module from all CUDA contexts where it is loaded
        # note: we ensure that this doesn't change the current CUDA context
        if self.cuda_modules:
            saved_context = runtime.core.cuda_context_get_current()
            for context, module in self.cuda_modules.items():
                runtime.core.cuda_unload_module(context, module)
            runtime.core.cuda_context_set_current(saved_context)
            self.cuda_modules = {}

        # clear kernel hooks
        self.kernel_hooks = {}

        # clear content hash
        self.content_hash = None

        # allow a new build attempt after the module has been modified
        self.cpu_build_failed = False
        self.cuda_build_failed = False

    # lookup and cache kernel entry points based on name, called after compilation / module load
    def get_kernel_hooks(self, kernel, device):
        """Return (and cache) the forward/backward entry points of ``kernel`` on ``device``.

        The module must already be loaded for ``device``. The CPU path reads
        ``self.cpu_module``; the CUDA path reads ``self.cuda_modules[device.context]``.
        """
        # get all hooks for this device
        device_hooks = self.kernel_hooks.get(device.context)
        if device_hooks is None:
            self.kernel_hooks[device.context] = device_hooks = {}

        # look up this kernel
        hooks = device_hooks.get(kernel)
        if hooks is not None:
            return hooks

        name = kernel.get_mangled_name()

        if device.is_cpu:
            func = ctypes.CFUNCTYPE(None)
            forward = func(
                runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_forward").encode("utf-8"))
            )
            backward = func(
                runtime.llvm.lookup(self.cpu_module.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))
            )
        else:
            cu_module = self.cuda_modules[device.context]
            forward = runtime.core.cuda_get_kernel(
                device.context, cu_module, (name + "_cuda_kernel_forward").encode("utf-8")
            )
            backward = runtime.core.cuda_get_kernel(
                device.context, cu_module, (name + "_cuda_kernel_backward").encode("utf-8")
            )

        hooks = KernelHooks(forward, backward)
        device_hooks[kernel] = hooks
        return hooks
# -------------------------------------------
# execution context
# a simple allocator
# TODO: use a pooled allocator to avoid hitting the system allocator
class Allocator:
    """Forwards host, pinned and device allocations to the native runtime for one device.

    No pooling is performed: every request goes straight to ``runtime.core``.
    """

    def __init__(self, device):
        self.device = device

    def alloc(self, size_in_bytes, pinned=False):
        """Allocate ``size_in_bytes`` bytes on this allocator's device and return the native pointer.

        On a CUDA device ``pinned`` is ignored. On the CPU, ``pinned=True`` uses
        ``alloc_pinned`` instead of ``alloc_host``.

        Raises:
            RuntimeError: if the device is a CUDA device with graph capture active.
        """
        if self.device.is_cuda:
            if self.device.is_capturing:
                # report the device, not this allocator object
                raise RuntimeError(f"Cannot allocate memory on device {self.device} while graph capture is active")
            return runtime.core.alloc_device(self.device.context, size_in_bytes)
        elif self.device.is_cpu:
            if pinned:
                return runtime.core.alloc_pinned(size_in_bytes)
            else:
                return runtime.core.alloc_host(size_in_bytes)

    def free(self, ptr, size_in_bytes, pinned=False):
        """Free ``ptr``, previously returned by ``alloc`` with the same ``pinned`` value.

        ``size_in_bytes`` is accepted for interface symmetry but is not used here.

        Raises:
            RuntimeError: if the device is a CUDA device with graph capture active.
        """
        if self.device.is_cuda:
            if self.device.is_capturing:
                # report the device, not this allocator object
                raise RuntimeError(f"Cannot free memory on device {self.device} while graph capture is active")
            return runtime.core.free_device(self.device.context, ptr)
        elif self.device.is_cpu:
            if pinned:
                return runtime.core.free_pinned(ptr)
            else:
                return runtime.core.free_host(ptr)
class ContextGuard:
    """Context manager that keeps CUDA context changes local to a ``with`` block.

    For a CUDA device, the device's context is pushed on entry and popped on exit.
    For any other device, if the CUDA driver is initialized, the current context is
    saved on entry and restored on exit. The saved context is stored on the guard
    itself, so one guard instance should not be entered re-entrantly for a non-CUDA
    device.
    """

    def __init__(self, device):
        self.device = device

    def __enter__(self):
        if not self.device.is_cuda:
            if is_cuda_driver_initialized():
                self.saved_context = runtime.core.cuda_context_get_current()
            return
        runtime.core.cuda_context_push_current(self.device.context)

    def __exit__(self, exc_type, exc_value, traceback):
        if not self.device.is_cuda:
            if is_cuda_driver_initialized():
                runtime.core.cuda_context_set_current(self.saved_context)
            return
        runtime.core.cuda_context_pop_current()
class Stream:
    """A CUDA stream on a specific device.

    Args:
        device: the device, or (once the runtime exists) any value accepted by
            ``runtime.get_device``. Before or during runtime initialization an
            explicit ``Device`` instance is required. The device must be a CUDA device.
        cuda_stream: optional keyword argument wrapping an existing native stream
            handle (``None`` is valid and denotes the CUDA default stream). A wrapped
            handle is not owned and is not destroyed by this object. Without it, a
            new stream is created and owned.
    """

    def __init__(self, device=None, **kwargs):
        # start as non-owning so __del__ is safe if construction fails below
        self.owner = False

        if runtime is None:
            # get_device() is unavailable during initialization, so require an explicit Device
            if not isinstance(device, Device):
                raise RuntimeError(
                    "A device object is required when creating a stream before or during Warp initialization"
                )
        else:
            device = runtime.get_device(device)

        if not device.is_cuda:
            raise RuntimeError(f"Device {device} is not a CUDA device")

        # presence in kwargs is tested because cuda_stream=None is a meaningful value (CUDA default stream)
        if "cuda_stream" not in kwargs:
            self.cuda_stream = device.runtime.core.cuda_stream_create(device.context)
            if not self.cuda_stream:
                raise RuntimeError(f"Failed to create stream on device {device}")
            self.owner = True
        else:
            self.cuda_stream = kwargs["cuda_stream"]

        self.device = device

    def __del__(self):
        if not self.owner:
            return
        runtime.core.cuda_stream_destroy(self.device.context, self.cuda_stream)

    def record_event(self, event=None):
        """Record ``event`` (or a new ``Event`` on this stream's device) on this stream and return it.

        Raises:
            RuntimeError: if ``event`` belongs to a different device.
        """
        if event is not None and event.device != self.device:
            raise RuntimeError(
                f"Event from device {event.device} cannot be recorded on stream from device {self.device}"
            )
        if event is None:
            event = Event(self.device)

        runtime.core.cuda_event_record(self.device.context, event.cuda_event, self.cuda_stream)

        return event

    def wait_event(self, event):
        """Make this stream wait for ``event``."""
        runtime.core.cuda_stream_wait_event(self.device.context, self.cuda_stream, event.cuda_event)

    def wait_stream(self, other_stream, event=None):
        """Make this stream wait for the work currently queued on ``other_stream``.

        ``event`` is used for the synchronization; when omitted, a new ``Event`` on
        ``other_stream``'s device is created.
        """
        sync_event = Event(other_stream.device) if event is None else event

        runtime.core.cuda_stream_wait_stream(
            self.device.context, self.cuda_stream, other_stream.cuda_stream, sync_event.cuda_event
        )
class Event:
    """A CUDA event on a specific device.

    Args:
        device: the device (resolved with ``get_device``) the event belongs to; must be a CUDA device.
        cuda_event: an existing native event handle to wrap. A wrapped handle is not
            owned and is not destroyed by this object.
        enable_timing: when creating a new event, keep timing enabled. Otherwise the
            ``DISABLE_TIMING`` flag is passed to the native create call.
    """

    # event creation flags
    class Flags:
        DEFAULT = 0x0
        BLOCKING_SYNC = 0x1
        DISABLE_TIMING = 0x2

    def __init__(self, device=None, cuda_event=None, enable_timing=False):
        # start as non-owning so __del__ is safe if construction fails below
        self.owner = False

        device = get_device(device)
        if not device.is_cuda:
            raise RuntimeError(f"Device {device} is not a CUDA device")

        self.device = device

        if cuda_event is None:
            creation_flags = Event.Flags.DEFAULT
            if enable_timing is False or not enable_timing:
                creation_flags = creation_flags | Event.Flags.DISABLE_TIMING
            self.cuda_event = runtime.core.cuda_event_create(device.context, creation_flags)
            if not self.cuda_event:
                raise RuntimeError(f"Failed to create event on device {device}")
            self.owner = True
        else:
            self.cuda_event = cuda_event

    def __del__(self):
        if not self.owner:
            return
        runtime.core.cuda_event_destroy(self.device.context, self.cuda_event)
class Device:
    """A device Warp can execute on: the CPU (``ordinal == -1``) or a CUDA device (``ordinal >= 0``).

    Args:
        runtime: the ``Runtime`` whose native library (``runtime.core``) services this device.
        alias: user-facing name, returned by ``str()`` and matched by string comparison.
        ordinal: ``-1`` for the CPU, otherwise the CUDA device index.
        is_primary: whether this device uses the CUDA primary context. That context
            is acquired lazily on first access of ``context``.
        context: an existing CUDA context; must be provided for non-primary CUDA devices.

    Raises:
        RuntimeError: if ``ordinal`` is neither -1 nor a valid CUDA device index.
    """

    def __init__(self, runtime, alias, ordinal=-1, is_primary=False, context=None):
        self.runtime = runtime
        self.alias = alias
        self.ordinal = ordinal
        self.is_primary = is_primary

        # context can be None to avoid acquiring primary contexts until the device is used
        self._context = context

        # if the device context is not primary, it cannot be None
        if ordinal != -1 and not is_primary:
            assert context is not None

        # streams will be created when context is acquired
        self._stream = None
        self.null_stream = None

        # indicates whether CUDA graph capture is active for this device
        self.is_capturing = False

        self.allocator = Allocator(self)
        self.context_guard = ContextGuard(self)

        if self.ordinal == -1:
            # CPU device
            self.name = platform.processor() or "CPU"
            self.arch = 0
            self.is_uva = False
            self.is_cubin_supported = False
            self.is_mempool_supported = False

            # TODO: add more device-specific dispatch functions
            self.memset = runtime.core.memset_host
            self.memtile = runtime.core.memtile_host

        elif ordinal >= 0 and ordinal < runtime.core.cuda_device_get_count():
            # CUDA device
            self.name = runtime.core.cuda_device_get_name(ordinal).decode()
            self.arch = runtime.core.cuda_device_get_arch(ordinal)
            self.is_uva = runtime.core.cuda_device_is_uva(ordinal)
            # check whether our NVRTC can generate CUBINs for this architecture
            self.is_cubin_supported = self.arch in runtime.nvrtc_supported_archs
            self.is_mempool_supported = runtime.core.cuda_device_is_memory_pool_supported(ordinal)

            # Warn the user of a possible misconfiguration of their system
            if not self.is_mempool_supported:
                warp.utils.warn(
                    f"Support for stream ordered memory allocators was not detected on device {ordinal}. "
                    "This can prevent the use of graphs and/or result in poor performance. "
                    "Is the UVM driver enabled?"
                )

            # initialize streams unless context acquisition is postponed
            if self._context is not None:
                self.init_streams()

            # TODO: add more device-specific dispatch functions
            self.memset = lambda ptr, value, size: runtime.core.memset_device(self.context, ptr, value, size)
            self.memtile = lambda ptr, src, srcsize, reps: runtime.core.memtile_device(
                self.context, ptr, src, srcsize, reps
            )

        else:
            raise RuntimeError(f"Invalid device ordinal ({ordinal})")

    def init_streams(self):
        """Create this device's default work stream and a wrapper for the CUDA default (null) stream."""
        # create a stream for asynchronous work
        self.stream = Stream(self)

        # CUDA default stream for some synchronous operations
        self.null_stream = Stream(self, cuda_stream=None)

    @property
    def is_cpu(self):
        return self.ordinal < 0

    @property
    def is_cuda(self):
        return self.ordinal >= 0

    @property
    def context(self):
        """The device's CUDA context (``None`` for the CPU).

        For a primary device the primary context is retained on first access. It is
        then registered in ``runtime.context_map`` and the device streams are created.
        """
        if self._context is not None:
            return self._context
        elif self.is_primary:
            # acquire primary context on demand
            self._context = self.runtime.core.cuda_device_primary_context_retain(self.ordinal)
            if self._context is None:
                raise RuntimeError(f"Failed to acquire primary context for device {self}")
            self.runtime.context_map[self._context] = self
            # initialize streams
            self.init_streams()
        return self._context

    @property
    def has_context(self):
        return self._context is not None

    @property
    def stream(self):
        # accessing self.context may lazily acquire the primary context (and create streams)
        if self.context:
            return self._stream
        else:
            raise RuntimeError(f"Device {self} is not a CUDA device")

    @stream.setter
    def stream(self, s):
        if self.is_cuda:
            if s.device != self:
                raise RuntimeError(f"Stream from device {s.device} cannot be used on device {self}")
            self._stream = s
            self.runtime.core.cuda_context_set_stream(self.context, s.cuda_stream)
        else:
            raise RuntimeError(f"Device {self} is not a CUDA device")

    @property
    def has_stream(self):
        return self._stream is not None

    def __str__(self):
        return self.alias

    def __repr__(self):
        return f"'{self.alias}'"

    def __eq__(self, other):
        # Devices compare equal by context. The string "cuda" matches the runtime's
        # current CUDA device; any other string is compared against the alias.
        if self is other:
            return True
        elif isinstance(other, Device):
            return self.context == other.context
        elif isinstance(other, str):
            if other == "cuda":
                return self == self.runtime.get_current_cuda_device()
            else:
                return other == self.alias
        else:
            return False

    def make_current(self):
        """Make this device's context current (no-op when it has no context, e.g. the CPU)."""
        if self.context is not None:
            self.runtime.core.cuda_context_set_current(self.context)

    def can_access(self, other):
        """Return whether this device can access memory of ``other`` (same context or CUDA peer access)."""
        other = self.runtime.get_device(other)
        if self.context == other.context:
            return True
        elif self.context is not None and other.context is not None:
            return bool(self.runtime.core.cuda_context_can_access_peer(self.context, other.context))
        else:
            return False
# Meta-type for arguments that can be resolved to a concrete Device:
# a Device instance, a device alias string, or None
# (None presumably selects the default device -- confirm against get_device).
Devicelike = Union[Device, str, None]
class Graph:
    """Owns an executable CUDA graph handle and destroys it when garbage-collected.

    Args:
        device: the CUDA device the graph was created on.
        exec: the native executable-graph handle.
    """

    def __init__(self, device: Device, exec: ctypes.c_void_p):
        self.device = device
        self.exec = exec

    def __del__(self):
        owner_device = self.device
        # the device's context guard keeps garbage collection from changing the current CUDA context
        with owner_device.context_guard:
            runtime.core.cuda_graph_destroy(owner_device.context, self.exec)
class Runtime:
def __init__(self):
bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")
if os.name == "nt":
if sys.version_info[0] > 3 or sys.version_info[0] == 3 and sys.version_info[1] >= 8:
# Python >= 3.8 this method to add dll search paths
os.add_dll_directory(bin_path)
else:
# Python < 3.8 we add dll directory to path
os.environ["PATH"] = bin_path + os.pathsep + os.environ["PATH"]
warp_lib = os.path.join(bin_path, "warp.dll")
llvm_lib = os.path.join(bin_path, "warp-clang.dll")
elif sys.platform == "darwin":
warp_lib = os.path.join(bin_path, "libwarp.dylib")
llvm_lib = os.path.join(bin_path, "libwarp-clang.dylib")
else:
warp_lib = os.path.join(bin_path, "warp.so")
llvm_lib = os.path.join(bin_path, "warp-clang.so")
self.core = self.load_dll(warp_lib)
if os.path.exists(llvm_lib):
self.llvm = self.load_dll(llvm_lib)
# setup c-types for warp-clang.dll
self.llvm.lookup.restype = ctypes.c_uint64
else:
self.llvm = None
# setup c-types for warp.dll
self.core.alloc_host.argtypes = [ctypes.c_size_t]
self.core.alloc_host.restype = ctypes.c_void_p
self.core.alloc_pinned.argtypes = [ctypes.c_size_t]
self.core.alloc_pinned.restype = ctypes.c_void_p
self.core.alloc_device.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
self.core.alloc_device.restype = ctypes.c_void_p
self.core.float_to_half_bits.argtypes = [ctypes.c_float]
self.core.float_to_half_bits.restype = ctypes.c_uint16
self.core.half_bits_to_float.argtypes = [ctypes.c_uint16]
self.core.half_bits_to_float.restype = ctypes.c_float
self.core.free_host.argtypes = [ctypes.c_void_p]
self.core.free_host.restype = None
self.core.free_pinned.argtypes = [ctypes.c_void_p]
self.core.free_pinned.restype = None
self.core.free_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.free_device.restype = None
self.core.memset_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
self.core.memset_host.restype = None
self.core.memset_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
self.core.memset_device.restype = None
self.core.memtile_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t]
self.core.memtile_host.restype = None
self.core.memtile_device.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
ctypes.c_size_t,
]
self.core.memtile_device.restype = None
self.core.memcpy_h2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
self.core.memcpy_h2h.restype = None
self.core.memcpy_h2d.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
self.core.memcpy_h2d.restype = None
self.core.memcpy_d2h.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
self.core.memcpy_d2h.restype = None
self.core.memcpy_d2d.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
self.core.memcpy_d2d.restype = None
self.core.memcpy_peer.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
self.core.memcpy_peer.restype = None
self.core.array_copy_host.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_copy_host.restype = ctypes.c_size_t
self.core.array_copy_device.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_copy_device.restype = ctypes.c_size_t
self.core.array_fill_host.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p, ctypes.c_int]
self.core.array_fill_host.restype = None
self.core.array_fill_device.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int,
ctypes.c_void_p,
ctypes.c_int,
]
self.core.array_fill_device.restype = None
self.core.array_sum_double_host.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_sum_float_host.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_sum_double_device.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_sum_float_device.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_inner_double_host.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_inner_float_host.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_inner_double_device.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_inner_float_device.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.array_scan_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
self.core.array_scan_float_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
self.core.array_scan_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
self.core.array_scan_float_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int, ctypes.c_bool]
self.core.radix_sort_pairs_int_host.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
self.core.radix_sort_pairs_int_device.argtypes = [ctypes.c_uint64, ctypes.c_uint64, ctypes.c_int]
self.core.runlength_encode_int_host.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
]
self.core.runlength_encode_int_device.argtypes = [
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_int,
]
self.core.bvh_create_host.restype = ctypes.c_uint64
self.core.bvh_create_host.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
self.core.bvh_create_device.restype = ctypes.c_uint64
self.core.bvh_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int]
self.core.bvh_destroy_host.argtypes = [ctypes.c_uint64]
self.core.bvh_destroy_device.argtypes = [ctypes.c_uint64]
self.core.bvh_refit_host.argtypes = [ctypes.c_uint64]
self.core.bvh_refit_device.argtypes = [ctypes.c_uint64]
self.core.mesh_create_host.restype = ctypes.c_uint64
self.core.mesh_create_host.argtypes = [
warp.types.array_t,
warp.types.array_t,
warp.types.array_t,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.mesh_create_device.restype = ctypes.c_uint64
self.core.mesh_create_device.argtypes = [
ctypes.c_void_p,
warp.types.array_t,
warp.types.array_t,
warp.types.array_t,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
]
self.core.mesh_destroy_host.argtypes = [ctypes.c_uint64]
self.core.mesh_destroy_device.argtypes = [ctypes.c_uint64]
self.core.mesh_refit_host.argtypes = [ctypes.c_uint64]
self.core.mesh_refit_device.argtypes = [ctypes.c_uint64]
self.core.hash_grid_create_host.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int]
self.core.hash_grid_create_host.restype = ctypes.c_uint64
self.core.hash_grid_destroy_host.argtypes = [ctypes.c_uint64]
self.core.hash_grid_update_host.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p, ctypes.c_int]
self.core.hash_grid_reserve_host.argtypes = [ctypes.c_uint64, ctypes.c_int]
self.core.hash_grid_create_device.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_int]
self.core.hash_grid_create_device.restype = ctypes.c_uint64
self.core.hash_grid_destroy_device.argtypes = [ctypes.c_uint64]
self.core.hash_grid_update_device.argtypes = [ctypes.c_uint64, ctypes.c_float, ctypes.c_void_p, ctypes.c_int]
self.core.hash_grid_reserve_device.argtypes = [ctypes.c_uint64, ctypes.c_int]
self.core.cutlass_gemm.argtypes = [
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_char_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_float,
ctypes.c_float,
ctypes.c_bool,
ctypes.c_bool,
ctypes.c_bool,
ctypes.c_int,
]
self.core.cutlass_gemm.restypes = ctypes.c_bool
self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64]
self.core.volume_create_host.restype = ctypes.c_uint64
self.core.volume_get_buffer_info_host.argtypes = [
ctypes.c_uint64,
ctypes.POINTER(ctypes.c_void_p),
ctypes.POINTER(ctypes.c_uint64),
]
self.core.volume_get_tiles_host.argtypes = [
ctypes.c_uint64,
ctypes.POINTER(ctypes.c_void_p),
ctypes.POINTER(ctypes.c_uint64),
]
self.core.volume_destroy_host.argtypes = [ctypes.c_uint64]
self.core.volume_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint64]
self.core.volume_create_device.restype = ctypes.c_uint64
self.core.volume_f_from_tiles_device.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_bool,
]
self.core.volume_f_from_tiles_device.restype = ctypes.c_uint64
self.core.volume_v_from_tiles_device.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_bool,
]
self.core.volume_v_from_tiles_device.restype = ctypes.c_uint64
self.core.volume_i_from_tiles_device.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_int,
ctypes.c_float,
ctypes.c_int,
ctypes.c_float,
ctypes.c_float,
ctypes.c_float,
ctypes.c_bool,
]
self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64
self.core.volume_get_buffer_info_device.argtypes = [
ctypes.c_uint64,
ctypes.POINTER(ctypes.c_void_p),
ctypes.POINTER(ctypes.c_uint64),
]
self.core.volume_get_tiles_device.argtypes = [
ctypes.c_uint64,
ctypes.POINTER(ctypes.c_void_p),
ctypes.POINTER(ctypes.c_uint64),
]
self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]
self.core.volume_get_voxel_size.argtypes = [
ctypes.c_uint64,
ctypes.POINTER(ctypes.c_float),
ctypes.POINTER(ctypes.c_float),
ctypes.POINTER(ctypes.c_float),
]
bsr_matrix_from_triplets_argtypes = [
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
]
self.core.bsr_matrix_from_triplets_float_host.argtypes = bsr_matrix_from_triplets_argtypes
self.core.bsr_matrix_from_triplets_double_host.argtypes = bsr_matrix_from_triplets_argtypes
self.core.bsr_matrix_from_triplets_float_device.argtypes = bsr_matrix_from_triplets_argtypes
self.core.bsr_matrix_from_triplets_double_device.argtypes = bsr_matrix_from_triplets_argtypes
self.core.bsr_matrix_from_triplets_float_host.restype = ctypes.c_int
self.core.bsr_matrix_from_triplets_double_host.restype = ctypes.c_int
self.core.bsr_matrix_from_triplets_float_device.restype = ctypes.c_int
self.core.bsr_matrix_from_triplets_double_device.restype = ctypes.c_int
bsr_transpose_argtypes = [
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_int,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
ctypes.c_uint64,
]
self.core.bsr_transpose_float_host.argtypes = bsr_transpose_argtypes
self.core.bsr_transpose_double_host.argtypes = bsr_transpose_argtypes
self.core.bsr_transpose_float_device.argtypes = bsr_transpose_argtypes
self.core.bsr_transpose_double_device.argtypes = bsr_transpose_argtypes
self.core.is_cuda_enabled.argtypes = None
self.core.is_cuda_enabled.restype = ctypes.c_int
self.core.is_cuda_compatibility_enabled.argtypes = None
self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
self.core.is_cutlass_enabled.argtypes = None
self.core.is_cutlass_enabled.restype = ctypes.c_int
self.core.cuda_driver_version.argtypes = None
self.core.cuda_driver_version.restype = ctypes.c_int
self.core.cuda_toolkit_version.argtypes = None
self.core.cuda_toolkit_version.restype = ctypes.c_int
self.core.cuda_driver_is_initialized.argtypes = None
self.core.cuda_driver_is_initialized.restype = ctypes.c_bool
self.core.nvrtc_supported_arch_count.argtypes = None
self.core.nvrtc_supported_arch_count.restype = ctypes.c_int
self.core.nvrtc_supported_archs.argtypes = [ctypes.POINTER(ctypes.c_int)]
self.core.nvrtc_supported_archs.restype = None
self.core.cuda_device_get_count.argtypes = None
self.core.cuda_device_get_count.restype = ctypes.c_int
self.core.cuda_device_primary_context_retain.argtypes = [ctypes.c_int]
self.core.cuda_device_primary_context_retain.restype = ctypes.c_void_p
self.core.cuda_device_get_name.argtypes = [ctypes.c_int]
self.core.cuda_device_get_name.restype = ctypes.c_char_p
self.core.cuda_device_get_arch.argtypes = [ctypes.c_int]
self.core.cuda_device_get_arch.restype = ctypes.c_int
self.core.cuda_device_is_uva.argtypes = [ctypes.c_int]
self.core.cuda_device_is_uva.restype = ctypes.c_int
self.core.cuda_context_get_current.argtypes = None
self.core.cuda_context_get_current.restype = ctypes.c_void_p
self.core.cuda_context_set_current.argtypes = [ctypes.c_void_p]
self.core.cuda_context_set_current.restype = None
self.core.cuda_context_push_current.argtypes = [ctypes.c_void_p]
self.core.cuda_context_push_current.restype = None
self.core.cuda_context_pop_current.argtypes = None
self.core.cuda_context_pop_current.restype = None
self.core.cuda_context_create.argtypes = [ctypes.c_int]
self.core.cuda_context_create.restype = ctypes.c_void_p
self.core.cuda_context_destroy.argtypes = [ctypes.c_void_p]
self.core.cuda_context_destroy.restype = None
self.core.cuda_context_synchronize.argtypes = [ctypes.c_void_p]
self.core.cuda_context_synchronize.restype = None
self.core.cuda_context_check.argtypes = [ctypes.c_void_p]
self.core.cuda_context_check.restype = ctypes.c_uint64
self.core.cuda_context_get_device_ordinal.argtypes = [ctypes.c_void_p]
self.core.cuda_context_get_device_ordinal.restype = ctypes.c_int
self.core.cuda_context_is_primary.argtypes = [ctypes.c_void_p]
self.core.cuda_context_is_primary.restype = ctypes.c_int
self.core.cuda_context_get_stream.argtypes = [ctypes.c_void_p]
self.core.cuda_context_get_stream.restype = ctypes.c_void_p
self.core.cuda_context_set_stream.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_context_set_stream.restype = None
self.core.cuda_context_can_access_peer.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_context_can_access_peer.restype = ctypes.c_int
self.core.cuda_stream_create.argtypes = [ctypes.c_void_p]
self.core.cuda_stream_create.restype = ctypes.c_void_p
self.core.cuda_stream_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_stream_destroy.restype = None
self.core.cuda_stream_synchronize.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_stream_synchronize.restype = None
self.core.cuda_stream_wait_event.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_stream_wait_event.restype = None
self.core.cuda_stream_wait_stream.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
]
self.core.cuda_stream_wait_stream.restype = None
self.core.cuda_event_create.argtypes = [ctypes.c_void_p, ctypes.c_uint]
self.core.cuda_event_create.restype = ctypes.c_void_p
self.core.cuda_event_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_event_destroy.restype = None
self.core.cuda_event_record.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_event_record.restype = None
self.core.cuda_graph_begin_capture.argtypes = [ctypes.c_void_p]
self.core.cuda_graph_begin_capture.restype = None
self.core.cuda_graph_end_capture.argtypes = [ctypes.c_void_p]
self.core.cuda_graph_end_capture.restype = ctypes.c_void_p
self.core.cuda_graph_launch.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_graph_launch.restype = None
self.core.cuda_graph_destroy.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_graph_destroy.restype = None
self.core.cuda_compile_program.argtypes = [
ctypes.c_char_p,
ctypes.c_int,
ctypes.c_char_p,
ctypes.c_bool,
ctypes.c_bool,
ctypes.c_bool,
ctypes.c_bool,
ctypes.c_char_p,
]
self.core.cuda_compile_program.restype = ctypes.c_size_t
self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
self.core.cuda_load_module.restype = ctypes.c_void_p
self.core.cuda_unload_module.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_unload_module.restype = None
self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
self.core.cuda_get_kernel.restype = ctypes.c_void_p
self.core.cuda_launch_kernel.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
ctypes.c_int,
ctypes.POINTER(ctypes.c_void_p),
]
self.core.cuda_launch_kernel.restype = ctypes.c_size_t
self.core.cuda_graphics_map.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_graphics_map.restype = None
self.core.cuda_graphics_unmap.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_graphics_unmap.restype = None
self.core.cuda_graphics_device_ptr_and_size.argtypes = [
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_uint64),
ctypes.POINTER(ctypes.c_size_t),
]
self.core.cuda_graphics_device_ptr_and_size.restype = None
self.core.cuda_graphics_register_gl_buffer.argtypes = [ctypes.c_void_p, ctypes.c_uint32, ctypes.c_uint]
self.core.cuda_graphics_register_gl_buffer.restype = ctypes.c_void_p
self.core.cuda_graphics_unregister_resource.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
self.core.cuda_graphics_unregister_resource.restype = None
self.core.init.restype = ctypes.c_int
error = self.core.init()
if error != 0:
raise Exception("Warp initialization failed")
self.device_map = {} # device lookup by alias
self.context_map = {} # device lookup by context
# register CPU device
cpu_name = platform.processor()
if not cpu_name:
cpu_name = "CPU"
self.cpu_device = Device(self, "cpu")
self.device_map["cpu"] = self.cpu_device
self.context_map[None] = self.cpu_device
cuda_device_count = self.core.cuda_device_get_count()
if cuda_device_count > 0:
# get CUDA Toolkit and driver versions
self.toolkit_version = self.core.cuda_toolkit_version()
self.driver_version = self.core.cuda_driver_version()
# get all architectures supported by NVRTC
num_archs = self.core.nvrtc_supported_arch_count()
if num_archs > 0:
archs = (ctypes.c_int * num_archs)()
self.core.nvrtc_supported_archs(archs)
self.nvrtc_supported_archs = list(archs)
else:
self.nvrtc_supported_archs = []
# register CUDA devices
self.cuda_devices = []
self.cuda_primary_devices = []
for i in range(cuda_device_count):
alias = f"cuda:{i}"
device = Device(self, alias, ordinal=i, is_primary=True)
self.cuda_devices.append(device)
self.cuda_primary_devices.append(device)
self.device_map[alias] = device
# set default device
if cuda_device_count > 0:
if self.core.cuda_context_get_current() is not None:
self.set_default_device("cuda")
else:
self.set_default_device("cuda:0")
else:
# CUDA not available
self.set_default_device("cpu")
# initialize kernel cache
warp.build.init_kernel_cache(warp.config.kernel_cache_dir)
# print device and version information
if not warp.config.quiet:
print(f"Warp {warp.config.version} initialized:")
if cuda_device_count > 0:
toolkit_version = (self.toolkit_version // 1000, (self.toolkit_version % 1000) // 10)
driver_version = (self.driver_version // 1000, (self.driver_version % 1000) // 10)
print(
f" CUDA Toolkit: {toolkit_version[0]}.{toolkit_version[1]}, Driver: {driver_version[0]}.{driver_version[1]}"
)
else:
if self.core.is_cuda_enabled():
# Warp was compiled with CUDA support, but no devices are available
print(" CUDA devices not available")
else:
# Warp was compiled without CUDA support
print(" CUDA support not enabled in this build")
print(" Devices:")
print(f' "{self.cpu_device.alias}" | {self.cpu_device.name}')
for cuda_device in self.cuda_devices:
print(f' "{cuda_device.alias}" | {cuda_device.name} (sm_{cuda_device.arch})')
print(f" Kernel cache: {warp.config.kernel_cache_dir}")
# CUDA compatibility check
if cuda_device_count > 0 and not self.core.is_cuda_compatibility_enabled():
if self.driver_version < self.toolkit_version:
print("******************************************************************")
print("* WARNING: *")
print("* Warp was compiled without CUDA compatibility support *")
print("* (quick build). The CUDA Toolkit version used to build *")
print("* Warp is not fully supported by the current driver. *")
print("* Some CUDA functionality may not work correctly! *")
print("* Update the driver or rebuild Warp without the --quick flag. *")
print("******************************************************************")
# global tape
self.tape = None
def load_dll(self, dll_path):
    """Load a shared library through ctypes and return its handle.

    Args:
        dll_path: Path of the shared library to load.

    Returns:
        The ``ctypes.CDLL`` handle of the loaded library.

    Raises:
        RuntimeError: If the library cannot be loaded. When the loader error
            mentions ``GLIBCXX`` the message explains that the environment's
            libstdc++ is older than the one Warp was built against.
    """
    try:
        if sys.version_info >= (3, 8):
            # ``winmode`` only exists on Python 3.8+; 0 passes no LoadLibraryEx
            # flags on Windows and the argument is ignored on other platforms
            dll = ctypes.CDLL(dll_path, winmode=0)
        else:
            dll = ctypes.CDLL(dll_path)
    except OSError as e:
        if "GLIBCXX" in str(e):
            raise RuntimeError(
                f"Failed to load the shared library '{dll_path}'.\n"
                "The execution environment's libstdc++ runtime is older than the version the Warp library was built for.\n"
                "See https://nvidia.github.io/warp/_build/html/installation.html#conda-environments for details."
            ) from e
        raise RuntimeError(f"Failed to load the shared library '{dll_path}'") from e
    return dll
def get_device(self, ident: Devicelike = None) -> Device:
    """Resolve ``ident`` to a Warp ``Device``.

    Accepted forms:
      * a ``Device`` instance, returned unchanged;
      * ``None``, which selects the runtime's default device;
      * the string ``"cuda"``, which selects the current CUDA device;
      * any other string, looked up as a registered alias (unknown aliases
        raise ``KeyError``).

    Raises:
        RuntimeError: If ``ident`` is of any other type.
    """
    if isinstance(ident, Device):
        return ident
    if ident is None:
        return self.default_device
    if not isinstance(ident, str):
        raise RuntimeError(f"Unable to resolve device from argument of type {type(ident)}")
    if ident == "cuda":
        return self.get_current_cuda_device()
    return self.device_map[ident]
def set_default_device(self, ident: Devicelike):
    """Resolve ``ident`` with ``get_device`` and store it as the runtime's default device."""
    resolved = self.get_device(ident)
    self.default_device = resolved
def get_current_cuda_device(self):
    """Return the Warp device for the CUDA context current on this thread.

    Resolution order:
      1. a context already registered in ``context_map`` returns its device;
      2. an unregistered primary context maps to the primary device with the
         context's ordinal, which is then registered in ``context_map``;
      3. any other unregistered context is registered as a new device with the
         unique alias ``cuda!<context in hex>``;
      4. with no current context, the default device is returned if it is a
         CUDA device, otherwise the first entry of ``cuda_devices``.

    Raises:
        RuntimeError: If no CUDA context is current and no CUDA devices exist.
    """
    current_context = self.core.cuda_context_get_current()
    if current_context is not None:
        current_device = self.context_map.get(current_context)
        if current_device is not None:
            # this is a known device
            return current_device
        if self.core.cuda_context_is_primary(current_context):
            # a primary context that we haven't used yet; index the list of primary
            # devices (as map_cuda_device does) because ``cuda_devices`` is shifted
            # whenever a device is removed from it by unmap_cuda_device
            ordinal = self.core.cuda_context_get_device_ordinal(current_context)
            device = self.cuda_primary_devices[ordinal]
            self.context_map[current_context] = device
            return device
        # an unseen non-primary context: register it as a new device with a unique alias
        alias = f"cuda!{current_context:x}"
        return self.map_cuda_device(alias, current_context)
    if self.default_device.is_cuda:
        return self.default_device
    if self.cuda_devices:
        return self.cuda_devices[0]
    raise RuntimeError("CUDA is not available")
def rename_device(self, device, alias):
    """Move ``device`` to a new alias in ``device_map`` and return it.

    The old alias entry is removed first; ``KeyError`` is raised (and nothing
    is changed) if the device's current alias is not registered.
    """
    self.device_map.pop(device.alias)
    device.alias = alias
    self.device_map[alias] = device
    return device
def map_cuda_device(self, alias, context=None) -> Device:
    """Bind ``alias`` to a CUDA context and return the corresponding Warp device.

    Args:
        alias: Device alias to assign.
        context: CUDA context handle; the current context is used when ``None``.

    Returns:
        The existing device if ``alias`` is already bound to ``context``; the
        renamed device if ``context`` (or its primary device) is already known;
        otherwise a newly created non-primary device.

    Raises:
        RuntimeError: If no context can be determined, or if ``alias`` is
            already bound to a device with a different context.
    """
    if context is None:
        context = self.core.cuda_context_get_current()
    if context is None:
        raise RuntimeError(f"Unable to determine CUDA context for device alias '{alias}'")

    # an alias that is already taken is fine only if it refers to the same context
    if alias in self.device_map:
        existing = self.device_map[alias]
        if context == existing.context:
            return existing
        raise RuntimeError(f"Device alias '{alias}' already exists")

    # a Warp device already owns this context: give it the new alias
    if context in self.context_map:
        return self.rename_device(self.context_map[context], alias)

    ordinal = self.core.cuda_context_get_device_ordinal(context)

    # primary context of a device that hasn't been used yet
    if self.core.cuda_context_is_primary(context):
        return self.rename_device(self.cuda_primary_devices[ordinal], alias)

    # wrap the external context in a new non-primary Warp device
    new_device = Device(self, alias, ordinal=ordinal, is_primary=False, context=context)
    self.device_map[alias] = new_device
    self.context_map[context] = new_device
    self.cuda_devices.append(new_device)
    return new_device
def unmap_cuda_device(self, alias):
    """Unregister the CUDA device known as ``alias``.

    The device is removed from ``device_map``, ``context_map`` and
    ``cuda_devices``.

    Raises:
        RuntimeError: If ``alias`` is unknown or does not refer to a CUDA device.
    """
    device = self.device_map.get(alias)
    is_cuda_alias = device is not None and device.is_cuda
    if not is_cuda_alias:
        raise RuntimeError(f"Invalid CUDA device alias '{alias}'")
    self.device_map.pop(alias)
    self.context_map.pop(device.context)
    self.cuda_devices.remove(device)
def verify_cuda_device(self, device: Devicelike = None):
    """Check a CUDA device for errors when ``warp.config.verify_cuda`` is enabled.

    Does nothing when ``warp.config.verify_cuda`` is disabled or when the
    resolved device is not a CUDA device.

    Args:
        device: Device to check; ``None`` selects the default device.

    Raises:
        RuntimeError: If ``cuda_context_check`` reports a nonzero error code.
    """
    if not warp.config.verify_cuda:
        return
    # resolve through this runtime instance, not the module-level ``runtime`` global
    device = self.get_device(device)
    if not device.is_cuda:
        return
    err = self.core.cuda_context_check(device.context)
    if err != 0:
        raise RuntimeError(f"CUDA error detected: {err}")
def assert_initialized():
    """Raise ``AssertionError`` if the global Warp runtime has not been created.

    An explicit ``raise`` is used rather than an ``assert`` statement so the
    check still runs under ``python -O`` (which strips asserts), while keeping
    the exception type that callers already see.
    """
    if runtime is None:
        raise AssertionError("Warp not initialized, call wp.init() before use")
# global entry points
def is_cpu_available():
    """Return ``True`` if the CPU device is usable, i.e. the runtime's ``llvm`` handle is set.

    Returns a plain bool instead of the underlying handle, and checks that Warp
    is initialized like the other module-level entry points.
    """
    assert_initialized()
    return bool(runtime.llvm)
def is_cuda_available():
    """Return ``True`` if at least one CUDA device is registered with the runtime."""
    device_count = get_cuda_device_count()
    return device_count > 0
def is_device_available(device):
    """Return ``True`` if ``device`` is among the devices listed by ``get_devices()``."""
    available_devices = get_devices()
    return device in available_devices
def is_cuda_driver_initialized() -> bool:
    """Return ``True`` if the CUDA driver has been initialized.

    This is stricter than ``is_cuda_available()``: a CUDA driver call
    (``cuCtxGetCurrent``) is made and its result compared with `CUDA_SUCCESS`.
    `CUDA_SUCCESS` is returned by ``cuCtxGetCurrent`` even when no context is
    bound to the calling CPU thread, so ``True`` does not imply a current context.

    Useful, for example, when ``cuInit()`` may have been called before a fork.
    """
    assert_initialized()
    core = runtime.core
    return core.cuda_driver_is_initialized()
def get_devices() -> List[Device]:
    """Return a new list of usable devices: the CPU device (if available) followed by every CUDA device."""
    assert_initialized()
    devices = [runtime.cpu_device] if is_cpu_available() else []
    devices.extend(runtime.cuda_devices)
    return devices
def get_cuda_device_count() -> int:
    """Return the number of CUDA devices currently registered with the runtime."""
    assert_initialized()
    cuda_devices = runtime.cuda_devices
    return len(cuda_devices)
def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
    """Return the CUDA device at position ``ordinal`` of the runtime's CUDA device list.

    When ``ordinal`` is ``None`` the current CUDA device is returned instead.
    """
    assert_initialized()
    if ordinal is not None:
        return runtime.cuda_devices[ordinal]
    return runtime.get_current_cuda_device()
def get_cuda_devices() -> List[Device]:
    """Return the runtime's list of CUDA devices (the runtime's own list object, not a copy)."""
    assert_initialized()
    cuda_devices = runtime.cuda_devices
    return cuda_devices
def get_preferred_device() -> Device:
    """Return the preferred compute device.

    The first CUDA device wins if any exist; otherwise the CPU device is
    returned when available, and ``None`` when neither is.
    """
    assert_initialized()
    if is_cuda_available():
        return runtime.cuda_devices[0]
    if is_cpu_available():
        return runtime.cpu_device
    return None
def get_device(ident: Devicelike = None) -> Device:
    """Return the Warp device identified by ``ident`` (``None`` selects the default device)."""
    assert_initialized()
    device = runtime.get_device(ident)
    return device
def set_device(ident: Devicelike):
    """Make the device identified by ``ident`` the default device and call its ``make_current()``."""
    assert_initialized()
    target = runtime.get_device(ident)
    runtime.set_default_device(target)
    target.make_current()
def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
    """Assign a device alias to a CUDA context.

    This function can be used to create a ``wp.Device`` for an external CUDA context.
    If a ``wp.Device`` already exists for the given context, its alias will change
    to the given value.

    Args:
        alias: A unique string to identify the device.
        context: A CUDA context pointer (CUcontext). If None, the currently bound
            CUDA context will be used.

    Returns:
        The associated wp.Device.

    Raises:
        RuntimeError: If no CUDA context can be determined, or if ``alias`` is
            already used by a device bound to a different context.
    """
    assert_initialized()
    return runtime.map_cuda_device(alias, context)
def unmap_cuda_device(alias: str):
    """Unregister the CUDA device known as ``alias``.

    Raises ``RuntimeError`` if ``alias`` does not name a registered CUDA device.
    """
    assert_initialized()
    runtime.unmap_cuda_device(alias)
def get_stream(device: Devicelike = None) -> Stream:
    """Return the stream currently used by ``device`` (``None`` selects the default device)."""
    target = get_device(device)
    return target.stream
def set_stream(stream, device: Devicelike = None):
    """Make ``stream`` the stream used by ``device`` (``None`` selects the default device).

    For an external stream, the caller must keep the stream alive for as long
    as it is in use. ``wp.ScopedStream`` is usually the safer alternative.
    """
    target = get_device(device)
    target.stream = stream
def record_event(event: Event = None):
    """Record a CUDA event on the default device's current stream.

    Args:
        event: The event to record; when None, a new Event is created.

    Returns:
        The event that was recorded.
    """
    stream = get_stream()
    return stream.record_event(event)
def wait_event(event: Event):
    """Block further work on the default device's current stream until ``event`` completes.

    Args:
        event: The CUDA event to wait for.
    """
    stream = get_stream()
    stream.wait_event(event)
def wait_stream(stream: Stream, event: Event = None):
    """Make the current stream wait for another CUDA stream to complete its work.

    Args:
        stream: The stream whose pending work the current stream should wait for.
        event: Event to be used. If None, a new Event will be created.
    """
    get_stream().wait_stream(stream, event=event)
class RegisteredGLBuffer:
    """
    Helper object to register a GL buffer with CUDA so that it can be mapped to a Warp array.
    """

    # Specifies no hints about how this resource will be used.
    # It is therefore assumed that this resource will be
    # read from and written to by CUDA. This is the default value.
    NONE = 0x00

    # Specifies that CUDA will not write to this resource.
    READ_ONLY = 0x01

    # Specifies that CUDA will not read from this resource and will write over the
    # entire contents of the resource, so none of the data previously
    # stored in the resource will be preserved.
    WRITE_DISCARD = 0x02

    def __init__(self, gl_buffer_id: int, device: Devicelike = None, flags: int = NONE):
        """Create a new RegisteredGLBuffer object.

        Args:
            gl_buffer_id: The OpenGL buffer id (GLuint).
            device: The device to register the buffer with. If None, the current device will be used.
            flags: Registration flags; see the class constants (``NONE`` by default).
        """
        self.gl_buffer_id = gl_buffer_id
        self.device = get_device(device)
        self.context = self.device.context
        self.resource = runtime.core.cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags)

    def __del__(self):
        """Unregister the CUDA graphics resource, if one was registered."""
        # ``resource`` is missing when __init__ raised before assigning it, and is
        # None when registration returned a NULL handle (c_void_p restype); in
        # both cases there is nothing to unregister. ``runtime`` may already be
        # cleared during interpreter shutdown.
        resource = getattr(self, "resource", None)
        if resource is None or runtime is None:
            return
        runtime.core.cuda_graphics_unregister_resource(self.context, resource)

    def map(self, dtype, shape) -> warp.array:
        """Map the OpenGL buffer to a Warp array.

        Args:
            dtype: The type of each element in the array.
            shape: The shape of the array.

        Returns:
            A Warp array object representing the mapped OpenGL buffer. The array
            does not own the memory (``owner=False``).
        """
        runtime.core.cuda_graphics_map(self.context, self.resource)
        # out-parameters filled in by the native call
        ptr = ctypes.c_uint64(0)
        size = ctypes.c_size_t(0)
        runtime.core.cuda_graphics_device_ptr_and_size(
            self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size)
        )
        return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device, owner=False)

    def unmap(self):
        """Unmap the OpenGL buffer."""
        runtime.core.cuda_graphics_unmap(self.context, self.resource)
def zeros(
    shape: Tuple = None,
    dtype=float,
    device: Devicelike = None,
    requires_grad: bool = False,
    pinned: bool = False,
    **kwargs,
) -> warp.array:
    """Allocate an array and set every element to zero.

    Args:
        shape: Array dimensions
        dtype: Element type, e.g. warp.vec3, warp.mat33
        device: Device to allocate the array on
        requires_grad: Whether the array participates in back propagation
        pinned: Whether to use pinned host memory (CPU arrays only)
        **kwargs: Forwarded to ``empty``

    Returns:
        The zero-filled warp.array
    """
    result = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
    # zero on the device's null stream (the CUDA default stream) so the fill
    # is synchronous with work queued on other streams
    null_stream = result.device.null_stream
    with warp.ScopedStream(null_stream):
        result.zero_()
    return result
def zeros_like(
    src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
) -> warp.array:
    """Allocate a zero-filled array matching the shape and dtype of ``src``.

    Args:
        src: Template array providing shape, dtype and (by default) device
        device: Target device; defaults to ``src.device``
        requires_grad: Whether the array participates in back propagation
        pinned: Whether to use pinned host memory (CPU arrays only)

    Returns:
        The zero-filled warp.array
    """
    result = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)
    result.zero_()
    return result
def full(
    shape: Tuple = None,
    value=0,
    dtype=Any,
    device: Devicelike = None,
    requires_grad: bool = False,
    pinned: bool = False,
    **kwargs,
) -> warp.array:
    """Return an array with all elements initialized to the given value

    Args:
        shape: Array dimensions
        value: Element value
        dtype: Type of each element, e.g.: float, warp.vec3, warp.mat33, etc.
            When left as ``Any`` it is inferred from ``value``: ``int`` -> int32,
            ``float`` -> float32, Warp scalars/structs keep their own type, and
            1-D / 2-D sequences become vector / matrix types.
        device: Device that array will live on
        requires_grad: Whether the array will be tracked for back propagation
        pinned: Whether the array uses pinned host memory (only applicable to CPU arrays)
        **kwargs: Forwarded to ``empty``

    Returns:
        A warp.array object representing the allocation

    Raises:
        ValueError: If ``dtype`` is not given and cannot be inferred from ``value``.
    """
    if dtype == Any:
        # determine dtype from value
        value_type = type(value)
        if value_type is int:
            dtype = warp.int32
        elif value_type is float:
            dtype = warp.float32
        elif value_type in warp.types.scalar_types or hasattr(value_type, "_wp_scalar_type_"):
            dtype = value_type
        elif isinstance(value, warp.codegen.StructInstance):
            dtype = value._cls
        elif hasattr(value, "__len__"):
            # a sequence, assume it's a vector or matrix value
            try:
                # np.asarray copies only when needed; np.array(..., copy=False)
                # raises under NumPy >= 2.0 whenever a copy is required (e.g. lists)
                na = np.asarray(value)
            except Exception as e:
                raise ValueError(f"Failed to interpret the value as a vector or matrix: {e}") from e

            # determine the scalar type
            scalar_type = warp.types.np_dtype_to_warp_type.get(na.dtype)
            if scalar_type is None:
                raise ValueError(f"Failed to convert {na.dtype} to a Warp data type")

            # determine if vector or matrix
            if na.ndim == 1:
                dtype = warp.types.vector(na.size, scalar_type)
            elif na.ndim == 2:
                dtype = warp.types.matrix(na.shape, scalar_type)
            else:
                raise ValueError("Values with more than two dimensions are not supported")
        else:
            raise ValueError(f"Invalid value type for Warp array: {value_type}")

    arr = empty(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)

    # use the CUDA default stream for synchronous behaviour with other streams
    with warp.ScopedStream(arr.device.null_stream):
        arr.fill_(value)

    return arr
def full_like(
    src: warp.array, value: Any, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
) -> warp.array:
    """Allocate an array matching the shape and dtype of ``src`` with every element set to ``value``.

    Args:
        src: Template array providing shape, dtype and (by default) device
        value: Fill value for every element
        device: Target device; defaults to ``src.device``
        requires_grad: Whether the array participates in back propagation
        pinned: Whether to use pinned host memory (CPU arrays only)

    Returns:
        The filled warp.array
    """
    result = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)
    result.fill_(value)
    return result
def clone(src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None) -> warp.array:
    """Allocate a new array like ``src`` and copy the contents of ``src`` into it.

    Args:
        src: Array to copy
        device: Target device; defaults to ``src.device``
        requires_grad: Whether the copy participates in back propagation
        pinned: Whether to use pinned host memory (CPU arrays only)

    Returns:
        The newly allocated copy as a warp.array
    """
    duplicate = empty_like(src, device=device, requires_grad=requires_grad, pinned=pinned)
    warp.copy(duplicate, src)
    return duplicate
def empty(
    shape: Tuple = None,
    dtype=float,
    device: Devicelike = None,
    requires_grad: bool = False,
    pinned: bool = False,
    **kwargs,
) -> warp.array:
    """Allocate an array without initializing its contents.

    Args:
        shape: Array dimensions; ``None`` produces a zero-sized array
        dtype: Element type, e.g. `warp.vec3`, `warp.mat33`
        device: Device to allocate the array on
        requires_grad: Whether the array participates in back propagation
        pinned: Whether to use pinned host memory (CPU arrays only)
        **kwargs: Forwarded to ``warp.array``; a legacy ``n=length`` keyword is
            accepted and overrides ``shape`` with ``(length,)``

    Returns:
        The uninitialized warp.array
    """
    # backwards compatibility: older code called wp.empty(n=length, ...)
    if "n" in kwargs:
        shape = (kwargs.pop("n"),)

    # always pass an explicit shape, even for a zero-sized array
    if shape is None:
        shape = 0

    return warp.array(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
def empty_like(
    src: warp.array, device: Devicelike = None, requires_grad: bool = None, pinned: bool = None
) -> warp.array:
    """Allocate an uninitialized array with the shape and dtype of ``src``.

    Each option left as ``None`` is taken from ``src``: ``device`` from
    ``src.device``, and ``requires_grad`` / ``pinned`` from the matching
    attributes of ``src`` (``False`` when ``src`` lacks them).

    Args:
        src: Template array providing shape, dtype and defaults
        device: Target device
        requires_grad: Whether the array participates in back propagation
        pinned: Whether to use pinned host memory (CPU arrays only)

    Returns:
        The uninitialized warp.array
    """
    if device is None:
        device = src.device
    if requires_grad is None:
        requires_grad = getattr(src, "requires_grad", False)
    if pinned is None:
        pinned = getattr(src, "pinned", False)
    return empty(shape=src.shape, dtype=src.dtype, device=device, requires_grad=requires_grad, pinned=pinned)
def from_numpy(
    arr: np.ndarray,
    dtype: Optional[type] = None,
    shape: Optional[Sequence[int]] = None,
    device: Optional[Devicelike] = None,
    requires_grad: bool = False,
) -> warp.array:
    """Create a Warp array from a NumPy array.

    When ``dtype`` is omitted it is inferred from ``arr``. A 2-D array becomes
    an array of vectors of length ``arr.shape[1]``. A 3-D array becomes an
    array of ``arr.shape[1] x arr.shape[2]`` matrices. Any other rank keeps the
    matching Warp scalar type. The Warp array is constructed with ``owner=False``.

    Args:
        arr: Source NumPy array
        dtype: Warp element type; inferred from ``arr`` when None
        shape: Optional shape forwarded to ``warp.array``
        device: Target device
        requires_grad: Whether the array participates in back propagation

    Returns:
        The resulting warp.array

    Raises:
        RuntimeError: If ``dtype`` is omitted and ``arr.dtype`` has no Warp equivalent.
    """
    if dtype is None:
        scalar_type = warp.types.np_dtype_to_warp_type.get(arr.dtype)
        if scalar_type is None:
            raise RuntimeError(f"Unsupported NumPy data type '{arr.dtype}'.")

        rank = len(arr.shape)
        if rank == 2:
            dtype = warp.types.vector(length=arr.shape[1], dtype=scalar_type)
        elif rank == 3:
            dtype = warp.types.matrix(shape=(arr.shape[1], arr.shape[2]), dtype=scalar_type)
        else:
            dtype = scalar_type

    return warp.array(
        data=arr,
        dtype=dtype,
        shape=shape,
        owner=False,
        device=device,
        requires_grad=requires_grad,
    )
# given a kernel destination argument type and a value convert
# to a c-type that can be passed to a kernel
def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
    """Convert ``value`` into the ctypes object passed to ``kernel`` for a single parameter.

    Args:
        kernel: The kernel being launched (used in error messages)
        arg_type: The declared type of the kernel parameter
        arg_name: The parameter name (used in error messages)
        value: The Python value supplied by the caller
        device: The device the kernel will be launched on; array arguments must live on it
        adjoint: True when packing adjoint (gradient) arguments

    Returns:
        A ctypes object (array descriptor, struct, vector/matrix or scalar) for the parameter.

    Raises:
        RuntimeError: If an array argument has the wrong array type, dtype, dimensionality or
            device, or if a scalar value cannot be packed.
        ValueError: If a value cannot be converted to the expected vector/matrix type.
    """
    if warp.types.is_array(arg_type):
        if value is None:
            # allow for NULL arrays
            return arg_type.__ctype__()

        adj = "adjoint " if adjoint else ""

        # check for array type
        # - in forward passes, array types have to match
        # - in backward passes, indexed array gradients are regular arrays
        if adjoint:
            array_matches = isinstance(value, warp.array)
        else:
            array_matches = type(value) is type(arg_type)

        if not array_matches:
            raise RuntimeError(
                f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array of type {type(arg_type)}, but passed value has type {type(value)}."
            )

        # check subtype
        if not warp.types.types_equal(value.dtype, arg_type.dtype):
            raise RuntimeError(
                f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with dtype={arg_type.dtype} but passed array has dtype={value.dtype}."
            )

        # check dimensions
        if value.ndim != arg_type.ndim:
            raise RuntimeError(
                f"Error launching kernel '{kernel.key}', {adj}argument '{arg_name}' expects an array with {arg_type.ndim} dimension(s) but the passed array has {value.ndim} dimension(s)."
            )

        # check device
        # NOTE: peer access (device.can_access(value.device)) is not considered; devices must match exactly
        if value.device != device:
            raise RuntimeError(
                f"Error launching kernel '{kernel.key}', trying to launch on device='{device}', but input array for argument '{arg_name}' is on device={value.device}."
            )

        return value.__ctype__()

    elif isinstance(arg_type, warp.codegen.Struct):
        assert value is not None
        return value.__ctype__()

    # try to convert to a value type (vec3, mat33, etc)
    elif issubclass(arg_type, ctypes.Array):
        if warp.types.types_equal(type(value), arg_type):
            return value
        # try constructing the required value from the argument (handles tuple / list, Gf.Vec3 case)
        try:
            return arg_type(value)
        except Exception as e:
            # chain the original error so the real conversion failure stays in the traceback
            raise ValueError(f"Failed to convert argument for param {arg_name} to {type_str(arg_type)}") from e

    elif isinstance(value, bool):
        return ctypes.c_bool(value)

    elif isinstance(value, arg_type):
        # value is already a Warp scalar instance; unwrap its .value
        try:
            # try to pack as a scalar type
            if arg_type is warp.types.float16:
                return arg_type._type_(warp.types.float_to_half_bits(value.value))
            else:
                return arg_type._type_(value.value)
        except Exception as e:
            raise RuntimeError(
                "Error launching kernel, unable to pack kernel parameter type "
                f"{type(value)} for param {arg_name}, expected {arg_type}"
            ) from e

    else:
        # plain Python number (or similar) for a scalar parameter
        try:
            # try to pack as a scalar type
            if arg_type is warp.types.float16:
                return arg_type._type_(warp.types.float_to_half_bits(value))
            else:
                return arg_type._type_(value)
        except Exception as e:
            # previously the cause was only printed to stdout; chaining keeps it in the traceback instead
            raise RuntimeError(
                "Error launching kernel, unable to pack kernel parameter type "
                f"{type(value)} for param {arg_name}, expected {arg_type}"
            ) from e
# represents all data required for a kernel launch
# so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
class Launch:
    """A recorded kernel launch that can be re-issued cheaply with :meth:`launch`.

    ``params`` holds the launch bounds at index 0 followed by one packed ctypes value per
    kernel argument. ``params_addr`` is a ``c_void_p`` array of the addresses of those values,
    which is what the CUDA launch path passes to the driver.
    """

    def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
        # if not specified look up hooks
        if not hooks:
            module = kernel.module
            if not module.load(device):
                # NOTE(review): returning here leaves the instance without attributes,
                # so a later call to launch() would fail with AttributeError
                return

            hooks = module.get_kernel_hooks(kernel, device)

        # if not specified set a zero bound
        if not bounds:
            bounds = warp.types.launch_bounds_t(0)

        # if not specified then build a list of default value params for args
        if not params:
            params = [bounds]

            for a in kernel.adj.args:
                if warp.types.is_array(a.type):
                    # every array flavor (array, indexedarray, fabricarray, ...) gets the same
                    # descriptor pack_arg produces for a NULL (None) array; packing 0 would fail
                    params.append(a.type.__ctype__())
                elif isinstance(a.type, warp.codegen.Struct):
                    params.append(a.type().__ctype__())
                else:
                    params.append(pack_arg(kernel, a.type, a.label, 0, device, False))

            kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
            kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)

            params_addr = kernel_params

        self.kernel = kernel
        self.hooks = hooks
        self.params = params
        self.params_addr = params_addr
        self.device = device
        self.bounds = bounds
        self.max_blocks = max_blocks

    def set_dim(self, dim):
        """Replace the launch bounds with new dimensions ``dim``."""
        self.bounds = warp.types.launch_bounds_t(dim)

        # launch bounds always at index 0
        self.params[0] = self.bounds

        # for CUDA kernels we need to update the address to each arg
        if self.params_addr:
            self.params_addr[0] = ctypes.c_void_p(ctypes.addressof(self.bounds))

    # set kernel param at an index, will convert to ctype as necessary
    def set_param_at_index(self, index, value):
        """Pack ``value`` for kernel argument ``index`` (0-based, excluding bounds) and store it."""
        arg_type = self.kernel.adj.args[index].type
        arg_name = self.kernel.adj.args[index].label

        carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, False)

        # params[0] holds the bounds, so argument i lives at i + 1
        self.params[index + 1] = carg

        # for CUDA kernels we need to update the address to each arg
        if self.params_addr:
            self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(carg))

    # set kernel param at an index without any type conversion
    # args must be passed as ctypes or basic int / float types
    def set_param_at_index_from_ctype(self, index, value):
        """Store a pre-packed value for kernel argument ``index`` without type conversion."""
        if isinstance(value, ctypes.Structure):
            # not sure how to directly assign struct->struct without reallocating using ctypes
            self.params[index + 1] = value

            # for CUDA kernels we need to update the address to each arg
            if self.params_addr:
                self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(value))

        else:
            # re-initialize the existing ctypes object in place so its address stays valid
            self.params[index + 1].__init__(value)

    # set kernel param by argument name
    def set_param_by_name(self, name, value):
        """Set the argument labelled ``name``; names that match no argument are ignored."""
        for i, arg in enumerate(self.kernel.adj.args):
            if arg.label == name:
                self.set_param_at_index(i, value)

    # set kernel param by argument name with no type conversions
    def set_param_by_name_from_ctype(self, name, value):
        """Like :meth:`set_param_by_name` but without type conversion."""
        # lookup argument index
        for i, arg in enumerate(self.kernel.adj.args):
            if arg.label == name:
                self.set_param_at_index_from_ctype(i, value)

    # set all params
    def set_params(self, values):
        """Pack and set all kernel arguments from ``values`` in declaration order."""
        for i, v in enumerate(values):
            self.set_param_at_index(i, v)

    # set all params without performing type-conversions
    def set_params_from_ctypes(self, values):
        """Set all kernel arguments from pre-packed ctypes ``values`` in declaration order."""
        for i, v in enumerate(values):
            self.set_param_at_index_from_ctype(i, v)

    def launch(self) -> Any:
        """Issue the recorded launch: CPU calls the forward hook directly, CUDA goes through the driver."""
        if self.device.is_cpu:
            self.hooks.forward(*self.params)
        else:
            runtime.core.cuda_launch_kernel(
                self.device.context, self.hooks.forward, self.bounds.size, self.max_blocks, self.params_addr
            )
def launch(
    kernel,
    dim: Tuple[int],
    inputs: List,
    outputs: Optional[List] = None,
    adj_inputs: Optional[List] = None,
    adj_outputs: Optional[List] = None,
    device: Devicelike = None,
    stream: Stream = None,
    adjoint=False,
    record_tape=True,
    record_cmd=False,
    max_blocks=0,
):
    """Launch a Warp kernel on the target device

    Kernel launches are asynchronous with respect to the calling Python thread.

    Args:
        kernel: The name of a Warp kernel function, decorated with the ``@wp.kernel`` decorator
        dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints with max of 4 dimensions
        inputs: The input parameters to the kernel
        outputs: The output parameters (optional)
        adj_inputs: The adjoint inputs (optional)
        adj_outputs: The adjoint outputs (optional)
        device: The device to launch on (optional)
        stream: The stream to launch on (optional)
        adjoint: Whether to run forward or backward pass (typically use False)
        record_tape: When true the launch will be recorded on the global wp.Tape() object when present
        record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
        max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
            If negative or zero, the maximum hardware value will be used.

    Returns:
        A ``Launch`` object when ``record_cmd`` is True for a non-empty forward launch, otherwise None.

    Raises:
        RuntimeError: If ``kernel`` is not a Warp kernel, the argument count is wrong, an argument
            cannot be packed, or the required forward/backward kernel is missing.
    """

    assert_initialized()

    # use None defaults instead of shared mutable lists
    if outputs is None:
        outputs = []
    if adj_inputs is None:
        adj_inputs = []
    if adj_outputs is None:
        adj_outputs = []

    # if stream is specified, use the associated device
    if stream is not None:
        device = stream.device
    else:
        device = runtime.get_device(device)

    # check function is a Kernel
    if not isinstance(kernel, Kernel):
        raise RuntimeError("Error launching kernel, can only launch functions decorated with @wp.kernel.")

    # debugging aid
    if warp.config.print_launches:
        print(f"kernel: {kernel.key} dim: {dim} inputs: {inputs} outputs: {outputs} device: {device}")

    # construct launch bounds
    bounds = warp.types.launch_bounds_t(dim)

    if bounds.size > 0:
        # first param is the number of threads
        params = []
        params.append(bounds)

        # converts arguments to kernel's expected ctypes and packs into params
        def pack_args(args, params, adjoint=False):
            for i, a in enumerate(args):
                arg_type = kernel.adj.args[i].type
                arg_name = kernel.adj.args[i].label

                params.append(pack_arg(kernel, arg_type, arg_name, a, device, adjoint))

        fwd_args = inputs + outputs
        adj_args = adj_inputs + adj_outputs

        if (len(fwd_args)) != (len(kernel.adj.args)):
            raise RuntimeError(
                f"Error launching kernel '{kernel.key}', passed {len(fwd_args)} arguments but kernel requires {len(kernel.adj.args)}."
            )

        # if it's a generic kernel, infer the required overload from the arguments
        if kernel.is_generic:
            fwd_types = kernel.infer_argument_types(fwd_args)
            kernel = kernel.get_overload(fwd_types)

        # delay load modules, including new overload if needed
        module = kernel.module
        if not module.load(device):
            return

        # late bind
        hooks = module.get_kernel_hooks(kernel, device)

        pack_args(fwd_args, params)
        pack_args(adj_args, params, adjoint=True)

        # run kernel
        if device.is_cpu:
            if adjoint:
                if hooks.backward is None:
                    raise RuntimeError(
                        f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                    )

                hooks.backward(*params)

            else:
                if hooks.forward is None:
                    raise RuntimeError(
                        f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                    )

                if record_cmd:
                    launch = Launch(
                        kernel=kernel,
                        hooks=hooks,
                        params=params,
                        params_addr=None,
                        bounds=bounds,
                        device=device,
                        max_blocks=max_blocks,
                    )
                    return launch
                else:
                    hooks.forward(*params)

        else:
            # CUDA launches take an array of pointers to the packed parameters
            kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
            kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)

            with warp.ScopedStream(stream):
                if adjoint:
                    if hooks.backward is None:
                        raise RuntimeError(
                            f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                        )

                    runtime.core.cuda_launch_kernel(
                        device.context, hooks.backward, bounds.size, max_blocks, kernel_params
                    )

                else:
                    if hooks.forward is None:
                        raise RuntimeError(
                            f"Failed to find forward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
                        )

                    if record_cmd:
                        # forward max_blocks so replaying the command honours the caller's limit
                        launch = Launch(
                            kernel=kernel,
                            hooks=hooks,
                            params=params,
                            params_addr=kernel_params,
                            bounds=bounds,
                            device=device,
                            max_blocks=max_blocks,
                        )
                        return launch

                    else:
                        # launch
                        runtime.core.cuda_launch_kernel(
                            device.context, hooks.forward, bounds.size, max_blocks, kernel_params
                        )

                try:
                    runtime.verify_cuda_device(device)
                except Exception as e:
                    print(f"Error launching kernel: {kernel.key} on device {device}")
                    raise e

    # record on tape if one is active
    if runtime.tape and record_tape:
        runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device)
def synchronize():
    """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices

    This method allows the host application code to ensure that any kernel launches
    or memory copies have completed.

    Raises:
        RuntimeError: If a device that has a context is currently capturing a CUDA graph.
    """

    if is_cuda_driver_initialized():
        # save the original context to avoid side effects
        saved_context = runtime.core.cuda_context_get_current()

        try:
            # TODO: only synchronize devices that have outstanding work
            for device in runtime.cuda_devices:
                # avoid creating primary context if the device has not been used yet
                if device.has_context:
                    if device.is_capturing:
                        raise RuntimeError(f"Cannot synchronize device {device} while graph capture is active")

                    runtime.core.cuda_context_synchronize(device.context)
        finally:
            # restore the original context even when synchronization raises
            runtime.core.cuda_context_set_current(saved_context)
def synchronize_device(device: Devicelike = None):
    """Block the calling CPU thread until all outstanding CUDA work on one device has finished.

    Use this to make sure kernel launches and memory copies issued to the device are complete.
    Non-CUDA devices are a no-op.

    Args:
        device: Device to synchronize. If None, synchronize the current CUDA device.

    Raises:
        RuntimeError: If the device is currently capturing a CUDA graph.
    """
    target = runtime.get_device(device)
    if not target.is_cuda:
        return

    if target.is_capturing:
        raise RuntimeError(f"Cannot synchronize device {target} while graph capture is active")

    runtime.core.cuda_context_synchronize(target.context)
def synchronize_stream(stream_or_device=None):
    """Block the calling CPU thread until all outstanding CUDA work on a stream has finished.

    Args:
        stream_or_device: `wp.Stream` or a device. If the argument is a device, synchronize the device's current stream.
    """
    target_stream = (
        stream_or_device if isinstance(stream_or_device, Stream) else runtime.get_device(stream_or_device).stream
    )
    runtime.core.cuda_stream_synchronize(target_stream.device.context, target_stream.cuda_stream)
def force_load(device: Union[Device, str, List[Device], List[str]] = None, modules: List[Module] = None):
    """Force user-defined kernels to be compiled and loaded

    Args:
        device: The device or list of devices to load the modules on. If None, load on all devices.
        modules: List of modules to load. If None, load all imported modules.
    """
    # Only save/restore the current CUDA context if the driver was already initialized on entry.
    # Previously the restore was keyed on is_cuda_available(), which could reference an unbound
    # saved_context when loading initialized CUDA during this call.
    restore_context = is_cuda_driver_initialized()
    saved_context = None
    if restore_context:
        # save original context to avoid side effects
        saved_context = runtime.core.cuda_context_get_current()

    try:
        if device is None:
            devices = get_devices()
        elif isinstance(device, list):
            devices = [get_device(device_item) for device_item in device]
        else:
            devices = [get_device(device)]

        if modules is None:
            modules = user_modules.values()

        for d in devices:
            for m in modules:
                m.load(d)
    finally:
        if restore_context:
            # restore original context to avoid side effects, even if a module failed to load
            runtime.core.cuda_context_set_current(saved_context)
def load_module(
    module: Union[Module, ModuleType, str] = None, device: Union[Device, str] = None, recursive: bool = False
):
    """Compile and load a user-defined module (and optionally its submodules).

    Args:
        module: The module to load, given as a Warp ``Module``, a Python module object or a
            module name. If None, the module of the caller is used.
        device: The device to load the modules on. If None, load on all devices.
        recursive: Whether to load submodules. E.g., if the given module is `warp.sim`, this will also load `warp.sim.model`, `warp.sim.articulation`, etc.

    Note: A module must be imported before it can be loaded by this function.

    Raises:
        TypeError: If ``module`` is not one of the accepted kinds.
    """
    if module is None:
        # default to the module that called us (frame 1 of the stack)
        caller_frame = inspect.stack()[1][0]
        module_name = inspect.getmodule(caller_frame).__name__
    elif isinstance(module, Module):
        module_name = module.name
    elif isinstance(module, ModuleType):
        module_name = module.__name__
    elif isinstance(module, str):
        module_name = module
    else:
        raise TypeError(f"Argument must be a module, got {type(module)}")

    # the named module itself, when registered
    root = user_modules.get(module_name)
    selected = [] if root is None else [root]

    # registered submodules share the "<name>." prefix
    if recursive:
        prefix = f"{module_name}."
        selected.extend(mod for name, mod in user_modules.items() if name.startswith(prefix))

    force_load(device=device, modules=selected)
def set_module_options(options: Dict[str, Any], module: Optional[Any] = None):
    """Set options for the current module.

    Options can be used to control runtime compilation and code-generation
    for the current module individually. Available options are listed below.

    * **mode**: The compilation mode to use, can be "debug", or "release", defaults to the value of ``warp.config.mode``.
    * **max_unroll**: The maximum fixed-size loop to unroll (default 16)

    Args:
        options: Set of key-value option pairs
        module: Python module (anything with a ``__name__``) whose options to update.
            Defaults to the module that called this function.

    The Warp module is unloaded after its options are updated.
    """

    if module is None:
        # default to the module that called us (frame 1 of the stack)
        m = inspect.getmodule(inspect.stack()[1][0])
    else:
        m = module

    # look the Warp module up once and use it for both the update and the unload
    warp_module = get_module(m.__name__)
    warp_module.options.update(options)
    # unload so the module is (presumably) rebuilt with the new options on its next load
    warp_module.unload()
def get_module_options(module: Optional[Any] = None) -> Dict[str, Any]:
    """Return the options dictionary of a Warp module.

    Args:
        module: Python module (anything with a ``__name__``) whose options to return.
            Defaults to the module that called this function.

    Returns:
        The module's ``options`` dictionary itself (not a copy).
    """
    source = module if module is not None else inspect.getmodule(inspect.stack()[1][0])
    return get_module(source.__name__).options
def capture_begin(device: Devicelike = None, stream=None, force_module_load=None):
    """Begin capture of a CUDA graph

    Captures all subsequent kernel launches and memory operations on CUDA devices.
    This can be used to record large numbers of kernels and replay them with low-overhead.

    Python garbage collection is disabled from here until :func:`~warp.capture_end()`.

    Args:
        device: The device to capture on, if None the current CUDA device will be used
        stream: The CUDA stream to capture on; when given, its device is used
        force_module_load: Whether or not to force loading of all kernels before capture, in general it is better to use :func:`~warp.load_module()` to selectively load kernels.

    Raises:
        RuntimeError: If CUDA error verification is enabled or the device is not a CUDA device.
    """

    if force_module_load is None:
        force_module_load = warp.config.graph_capture_module_load_default

    if warp.config.verify_cuda:
        raise RuntimeError("Cannot use CUDA error verification during graph capture")

    if stream is not None:
        device = stream.device
    else:
        device = runtime.get_device(device)

    if not device.is_cuda:
        raise RuntimeError("Must be a CUDA device")

    if force_module_load:
        force_load(device)

    device.is_capturing = True

    # disable garbage collection to avoid older allocations getting collected during graph capture
    gc.disable()

    try:
        with warp.ScopedStream(stream):
            runtime.core.cuda_graph_begin_capture(device.context)
    except BaseException:
        # roll back so a failed begin does not leave GC disabled or the device flagged as capturing
        device.is_capturing = False
        gc.enable()
        raise
def capture_end(device: Devicelike = None, stream=None) -> Graph:
    """Finish the CUDA graph capture started by :func:`~warp.capture_begin()`.

    Clears the device's capturing flag and re-enables Python garbage collection.

    Args:
        device: The device being captured, if None the current CUDA device is used
        stream: The CUDA stream being captured; when given, its device is used

    Returns:
        A handle to a CUDA graph object that can be launched with :func:`~warp.capture_launch()`

    Raises:
        RuntimeError: If the device is not a CUDA device or the capture failed.
    """
    capture_device = stream.device if stream is not None else runtime.get_device(device)

    if not capture_device.is_cuda:
        raise RuntimeError("Must be a CUDA device")

    with warp.ScopedStream(stream):
        graph_handle = runtime.core.cuda_graph_end_capture(capture_device.context)

    capture_device.is_capturing = False

    # re-enable GC
    gc.enable()

    if graph_handle is None:
        raise RuntimeError(
            "Error occurred during CUDA graph capture. This could be due to an unintended allocation or CPU/GPU synchronization event."
        )
    return Graph(capture_device, graph_handle)
def capture_launch(graph: Graph, stream: Stream = None):
    """Replay a previously captured CUDA graph.

    Args:
        graph: A Graph as returned by :func:`~warp.capture_end()`
        stream: A Stream to launch the graph on (optional); must belong to the graph's device

    Raises:
        RuntimeError: If ``stream`` belongs to a different device than ``graph``.
    """
    if stream is None:
        target_device = graph.device
    else:
        if stream.device != graph.device:
            raise RuntimeError(f"Cannot launch graph from device {graph.device} on stream from device {stream.device}")
        target_device = stream.device

    with warp.ScopedStream(stream):
        runtime.core.cuda_graph_launch(target_device.context, graph.exec)
def copy(
    dest: warp.array, src: warp.array, dest_offset: int = 0, src_offset: int = 0, count: int = 0, stream: Stream = None
):
    """Copy array contents from src to dest

    Args:
        dest: Destination array, must be at least as big as source buffer
        src: Source array
        dest_offset: Element offset in the destination array
        src_offset: Element offset in the source array
        count: Number of array elements to copy (will copy all elements if set to 0)
        stream: The stream on which to perform the copy (optional)

    Offsets and ``count`` are only honoured when both arrays are contiguous. Otherwise the
    arrays must have equal shapes and are copied in full.

    If both arrays carry gradients, the same element range of the gradients is copied too.

    Raises:
        RuntimeError: If either argument is not an array, the requested range exceeds either
            buffer, the shapes or element sizes are incompatible (non-contiguous path), or
            Fabric arrays of arrays are involved.
    """

    if not warp.types.is_array(src) or not warp.types.is_array(dest):
        raise RuntimeError("Copy source and destination must be arrays")

    # backwards compatibility, if count is zero then copy entire src array
    if count <= 0:
        count = src.size

    if count == 0:
        return

    # copying non-contiguous arrays requires that they are on the same device
    if not (src.is_contiguous and dest.is_contiguous) and src.device != dest.device:
        if dest.is_contiguous:
            # make a contiguous copy of the source array
            src = src.contiguous()
        else:
            # make a copy of the source array on the destination device
            src = src.to(dest.device)

    if src.is_contiguous and dest.is_contiguous:
        bytes_to_copy = count * warp.types.type_size_in_bytes(src.dtype)

        src_size_in_bytes = src.size * warp.types.type_size_in_bytes(src.dtype)
        dst_size_in_bytes = dest.size * warp.types.type_size_in_bytes(dest.dtype)

        src_offset_in_bytes = src_offset * warp.types.type_size_in_bytes(src.dtype)
        dst_offset_in_bytes = dest_offset * warp.types.type_size_in_bytes(dest.dtype)

        src_ptr = src.ptr + src_offset_in_bytes
        dst_ptr = dest.ptr + dst_offset_in_bytes

        if src_offset_in_bytes + bytes_to_copy > src_size_in_bytes:
            raise RuntimeError(
                f"Trying to copy source buffer with size ({bytes_to_copy}) from offset ({src_offset_in_bytes}) is larger than source size ({src_size_in_bytes})"
            )

        if dst_offset_in_bytes + bytes_to_copy > dst_size_in_bytes:
            raise RuntimeError(
                f"Trying to copy source buffer with size ({bytes_to_copy}) to offset ({dst_offset_in_bytes}) is larger than destination size ({dst_size_in_bytes})"
            )

        if src.device.is_cpu and dest.device.is_cpu:
            runtime.core.memcpy_h2h(dst_ptr, src_ptr, bytes_to_copy)
        else:
            # figure out the CUDA context/stream for the copy
            if stream is not None:
                copy_device = stream.device
            elif dest.device.is_cuda:
                copy_device = dest.device
            else:
                copy_device = src.device

            with warp.ScopedStream(stream):
                if src.device.is_cpu and dest.device.is_cuda:
                    runtime.core.memcpy_h2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
                elif src.device.is_cuda and dest.device.is_cpu:
                    runtime.core.memcpy_d2h(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
                elif src.device.is_cuda and dest.device.is_cuda:
                    if src.device == dest.device:
                        runtime.core.memcpy_d2d(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
                    else:
                        runtime.core.memcpy_peer(copy_device.context, dst_ptr, src_ptr, bytes_to_copy)
                else:
                    raise RuntimeError("Unexpected source and destination combination")

    else:
        # handle non-contiguous and indexed arrays

        if src.shape != dest.shape:
            raise RuntimeError("Incompatible array shapes")

        src_elem_size = warp.types.type_size_in_bytes(src.dtype)
        dst_elem_size = warp.types.type_size_in_bytes(dest.dtype)

        if src_elem_size != dst_elem_size:
            raise RuntimeError("Incompatible array data types")

        # can't copy to/from fabric arrays of arrays, because they are jagged arrays of arbitrary lengths
        # TODO?
        if (
            isinstance(src, (warp.fabricarray, warp.indexedfabricarray))
            and src.ndim > 1
            or isinstance(dest, (warp.fabricarray, warp.indexedfabricarray))
            and dest.ndim > 1
        ):
            raise RuntimeError("Copying to/from Fabric arrays of arrays is not supported")

        src_desc = src.__ctype__()
        dst_desc = dest.__ctype__()
        src_ptr = ctypes.pointer(src_desc)
        dst_ptr = ctypes.pointer(dst_desc)
        src_type = warp.types.array_type_id(src)
        dst_type = warp.types.array_type_id(dest)

        if src.device.is_cuda:
            with warp.ScopedStream(stream):
                runtime.core.array_copy_device(src.device.context, dst_ptr, src_ptr, dst_type, src_type, src_elem_size)
        else:
            runtime.core.array_copy_host(dst_ptr, src_ptr, dst_type, src_type, src_elem_size)

    # copy gradient, if needed, over the same element range as the data
    # (previously the full gradient was copied without offsets, mismatching partial copies)
    if hasattr(src, "grad") and src.grad is not None and hasattr(dest, "grad") and dest.grad is not None:
        copy(dest.grad, src.grad, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream)
def type_str(t):
    """Return the human-readable type name used in generated docs and stubs.

    Handles:

    * Sentinels (None, Any, Callable, Tuple[int, int]).
    * Integer literals, returned as their string form.
    * Lists of types, rendered as ``Tuple[...]``.
    * The Warp array kinds.
    * Generic vector/quaternion/matrix/transform types.
    * Any other object, via its ``__name__``.

    Raises:
        TypeError: For a generic Warp type whose generic kind is not recognized.
    """
    if t is None:
        return "None"
    if t == Any:
        return "Any"
    if t == Callable:
        return "Callable"
    if t == Tuple[int, int]:
        return "Tuple[int, int]"
    if isinstance(t, int):
        return str(t)
    if isinstance(t, List):
        return "Tuple[" + ", ".join(type_str(item) for item in t) + "]"

    # Warp array kinds, checked in this order
    array_labels = (
        (warp.array, "Array"),
        (warp.indexedarray, "IndexedArray"),
        (warp.fabricarray, "FabricArray"),
        (warp.indexedfabricarray, "IndexedFabricArray"),
    )
    for array_cls, label in array_labels:
        if isinstance(t, array_cls):
            return f"{label}[{type_str(t.dtype)}]"

    if not hasattr(t, "_wp_generic_type_str_"):
        return t.__name__

    kind = t._wp_generic_type_str_

    # for concrete vec/mat types use the short name
    if t in warp.types.vector_types:
        return t.__name__

    # for generic vector / matrix type use a Generic type hint
    if kind == "vec_t":
        return f"Vector[{type_str(t._wp_type_params_[0])},{type_str(t._wp_scalar_type_)}]"
    if kind == "quat_t":
        return f"Quaternion[{type_str(t._wp_scalar_type_)}]"
    if kind == "mat_t":
        return f"Matrix[{type_str(t._wp_type_params_[0])},{type_str(t._wp_type_params_[1])},{type_str(t._wp_scalar_type_)}]"
    if kind == "transform_t":
        return f"Transformation[{type_str(t._wp_scalar_type_)}]"
    raise TypeError("Invalid vector or matrix dimensions")
def print_function(f, file, noentry=False):  # pragma: no cover
    """Write a reST ``.. function::`` entry describing builtin ``f`` to ``file``.

    Args:
        f: The function being written
        file: The file object for output
        noentry: If True, then the :noindex: and :nocontentsentry: directive
            options will be added

    Returns:
        False when ``f`` is hidden (nothing written), True otherwise
    """
    if f.hidden:
        return False

    signature = ", ".join(f"{name}: {type_str(arg_type)}" for name, arg_type in f.input_types.items())

    # todo: construct a default value for each of the functions args
    # so we can generate the return type for overloaded functions;
    # until then, any failure simply omits the return annotation
    try:
        ret_annotation = " -> " + type_str(f.value_func(None, None, None))
    except Exception:
        ret_annotation = ""

    print(f".. function:: {f.key}({signature}){ret_annotation}", file=file)
    if noentry:
        for option in (" :noindex:", " :nocontentsentry:"):
            print(option, file=file)
    print("", file=file)

    if f.doc != "":
        # footnote [1] marks functions without gradient support
        footnote = " [1]_" if f.missing_grad else ""
        print(f" {f.doc}{footnote}", file=file)
        print("", file=file)

    print(file=file)
    return True
def export_functions_rst(file):  # pragma: no cover
    """Write the reST "Kernel Reference" page (types and builtin functions) to ``file``.

    Every overload is written. Overloads whose key already has an indexed entry are written
    with ``:noindex:``/``:nocontentsentry:`` so Sphinx sees a single indexed entry per name.

    Args:
        file: Text file object receiving the generated reST.
    """

    def emit(*lines):
        # one print() call per line, exactly as the output would be printed individually
        for line in lines:
            print(line, file=file)

    header = (
        "..\n"
        " Autogenerated File - Do not edit. Run build_docs.py to generate.\n"
        "\n"
        ".. functions:\n"
        ".. currentmodule:: warp\n"
        "\n"
        "Kernel Reference\n"
        "================"
    )
    emit(header)

    # type definitions of all functions by group
    emit("\nScalar Types", "------------")
    for scalar in warp.types.scalar_types:
        emit(f".. class:: {scalar.__name__}")
    # Manually add wp.bool since it's inconvenient to add to wp.types.scalar_types:
    emit(f".. class:: {warp.types.bool.__name__}")

    emit("\n\nVector Types", "------------")
    for vec in warp.types.vector_types:
        emit(f".. class:: {vec.__name__}")

    emit("\nGeneric Types", "-------------")
    for generic_name in ("Int", "Float", "Scalar", "Vector", "Matrix", "Quaternion", "Transformation", "Array"):
        emit(f".. class:: {generic_name}")

    emit("\nQuery Types", "-------------")
    for query_name in (
        "bvh_query_t",
        "hash_grid_query_t",
        "mesh_query_aabb_t",
        "mesh_query_point_t",
        "mesh_query_ray_t",
    ):
        emit(f".. autoclass:: {query_name}")

    # all overloads of all builtins, grouped by their group name (first-seen order)
    groups = {}
    for builtin in builtin_functions.values():
        groups.setdefault(builtin.group, []).extend(builtin.overloads)

    # keys that already have an indexed entry
    documented = set()

    for group_name, overloads in groups.items():
        emit("\n", group_name, "---------------")

        for overload in overloads:
            if overload.key in documented:
                # Add :noindex: + :nocontentsentry: since Sphinx gets confused
                print_function(overload, file=file, noentry=True)
            elif print_function(overload, file=file):
                documented.add(overload.key)

    # footnotes
    emit(".. rubric:: Footnotes", ".. [1] Note: function gradients not implemented for backpropagation.")
def export_stubs(file):  # pragma: no cover
    """Generates stub file for auto-complete of builtin functions

    Writes, in order:

    1. A typing preamble.
    2. The ``__init__.py`` located in the same directory as ``file``, with ``#`` comment lines removed.
    3. One ``@over`` stub per exported, non-hidden builtin overload.

    Args:
        file: Open text file to write to; its ``name`` attribute is used to locate ``__init__.py``.
    """
    import textwrap

    print(
        "# Autogenerated file, do not edit, this file provides stubs for builtins autocomplete in VSCode, PyCharm, etc",
        file=file,
    )
    print("", file=file)
    print("from typing import Any", file=file)
    print("from typing import Tuple", file=file)
    print("from typing import Callable", file=file)
    print("from typing import TypeVar", file=file)
    print("from typing import Generic", file=file)
    print("from typing import overload as over", file=file)
    print(file=file)

    # type hints, these need to be mirrored into the stubs file
    print('Length = TypeVar("Length", bound=int)', file=file)
    print('Rows = TypeVar("Rows", bound=int)', file=file)
    print('Cols = TypeVar("Cols", bound=int)', file=file)
    print('DType = TypeVar("DType")', file=file)

    print('Int = TypeVar("Int")', file=file)
    print('Float = TypeVar("Float")', file=file)
    print('Scalar = TypeVar("Scalar")', file=file)
    print("Vector = Generic[Length, Scalar]", file=file)
    print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
    print("Quaternion = Generic[Float]", file=file)
    print("Transformation = Generic[Float]", file=file)
    print("Array = Generic[DType]", file=file)
    print("FabricArray = Generic[DType]", file=file)
    print("IndexedFabricArray = Generic[DType]", file=file)

    # prepend __init__.py
    # read as UTF-8 (Python's source encoding) instead of the platform's locale encoding
    with open(os.path.join(os.path.dirname(file.name), "__init__.py"), encoding="utf-8") as header_file:
        # strip comment lines
        lines = [line for line in header_file if not line.startswith("#")]
        header = "".join(lines)

    print(header, file=file)
    print(file=file)

    for g in builtin_functions.values():
        for f in g.overloads:
            # skip non-exported / hidden overloads before formatting their signatures,
            # so type_str never sees the argument types of functions we do not emit
            if not f.export or f.hidden:  # or f.generic:
                continue

            args = ", ".join(f"{k}: {type_str(v)}" for k, v in f.input_types.items())

            return_str = ""
            try:
                # todo: construct a default value for each of the functions args
                # so we can generate the return type for overloaded functions
                return_type = f.value_func(None, None, None)
                if return_type:
                    return_str = " -> " + type_str(return_type)

            except Exception:
                pass

            print("@over", file=file)
            print(f"def {f.key}({args}){return_str}:", file=file)
            print(' """', file=file)
            print(textwrap.indent(text=f.doc, prefix=" "), file=file)
            print(' """', file=file)
            print(" ...\n\n", file=file)
def export_builtins(file: io.TextIOBase):  # pragma: no cover
    """Write C ``WP_API`` wrappers for exportable builtin overloads into namespace ``wp``.

    An overload is skipped when any of the following holds:

    * It is not exported, or it is generic.
    * It is variadic.
    * It takes array / Any / Callable / Tuple arguments.
    * Its return type cannot be derived from ``value_func(None, None, None)``.
    * Its return type name starts with "Tuple".

    Args:
        file: Text stream receiving the generated C++ source.
    """

    def ctype_arg_str(t):
        # vector types are passed by reference, everything else by value
        if isinstance(t, int):
            return "int"
        if isinstance(t, float):
            return "float"
        if t in warp.types.vector_types:
            return f"{t.__name__}&"
        return t.__name__

    def ctype_ret_str(t):
        if isinstance(t, int):
            return "int"
        if isinstance(t, float):
            return "float"
        return t.__name__

    def has_simple_inputs(input_types):
        # only export simple types that don't use arrays
        # or templated types
        return not any(
            isinstance(v, warp.array) or v == Any or v == Callable or v == Tuple for v in input_types.values()
        )

    file.write("namespace wp {\n\n")
    file.write('extern "C" {\n\n')

    for group in builtin_functions.values():
        for f in group.overloads:
            if not f.export or f.generic:
                continue
            if not has_simple_inputs(f.input_types) or f.variadic:
                continue

            args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in f.input_types.items())
            params = ", ".join(f.input_types.keys())

            try:
                # todo: construct a default value for each of the functions args
                # so we can generate the return type for overloaded functions
                return_type = ctype_ret_str(f.value_func(None, None, None))
            except Exception:
                continue

            if return_type.startswith("Tuple"):
                continue

            call = f"wp::{f.key}({params})"
            if args == "":
                file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = {call}; }}\n")
            elif return_type == "None":
                # NOTE(review): likely unreachable, since ctype_ret_str(None) raises and the overload is skipped above
                file.write(f"WP_API void {f.mangled_name}({args}) {{ {call}; }}\n")
            else:
                file.write(f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = {call}; }}\n")

    file.write('\n} // extern "C"\n\n')
    file.write("} // namespace wp\n")
# global Warp runtime instance; None until init() creates it
runtime = None
def init():
    """Initialize the Warp runtime. This function must be called before any other API call. If an error occurs an exception will be raised."""
    global runtime

    # idempotent: later calls keep the existing runtime
    if runtime is not None:
        return
    runtime = Runtime()