# NOTE: removed non-Python web-upload residue (uploader name / commit hash) that made the file unparseable.
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from __future__ import annotations

import ast
import builtins
import ctypes
import inspect
import re
import sys
import textwrap
import types
from typing import Any, Callable, List, Mapping, Optional

import warp.config
from warp.types import *
class WarpCodegenError(RuntimeError):
    """Raised when Warp code generation fails."""

    def __init__(self, message):
        # delegate straight to RuntimeError so e.args == (message,)
        super().__init__(message)
class WarpCodegenTypeError(TypeError):
    """Raised when codegen encounters mismatched or unsupported types."""

    def __init__(self, message):
        # delegate straight to TypeError so e.args == (message,)
        super().__init__(message)
class WarpCodegenAttributeError(AttributeError):
    """Raised when codegen fails to resolve an attribute access."""

    def __init__(self, message):
        # delegate straight to AttributeError so e.args == (message,)
        super().__init__(message)
class WarpCodegenKeyError(KeyError):
    """Raised when codegen fails a keyed lookup (e.g. unsupported AST node)."""

    def __init__(self, message):
        # delegate straight to KeyError so e.args == (message,)
        super().__init__(message)
# map operator to function name
# see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
# nice overview of python operators
builtin_operators = {
    ast.Add: "add",
    ast.Sub: "sub",
    ast.Mult: "mul",
    ast.MatMult: "mul",
    ast.Div: "div",
    ast.FloorDiv: "floordiv",
    ast.Pow: "pow",
    ast.Mod: "mod",
    ast.UAdd: "pos",
    ast.USub: "neg",
    ast.Not: "unot",
    ast.Gt: ">",
    ast.Lt: "<",
    ast.GtE: ">=",
    ast.LtE: "<=",
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.BitAnd: "bit_and",
    ast.BitOr: "bit_or",
    ast.BitXor: "bit_xor",
    ast.Invert: "invert",
    ast.LShift: "lshift",
    ast.RShift: "rshift",
}

# comparison operator strings that may participate in a chained comparison (a < b < c)
comparison_chain_strings = [
    builtin_operators[node]
    for node in (ast.Gt, ast.Lt, ast.LtE, ast.GtE, ast.Eq, ast.NotEq)
]


def op_str_is_chainable(op: str) -> builtins.bool:
    """Return True if the operator string can appear in a comparison chain."""
    return op in comparison_chain_strings
def get_annotations(obj: Any) -> Mapping[str, Any]:
    """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
    # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
    if not isinstance(obj, type):
        return getattr(obj, "__annotations__", {})

    # for classes, read only the class's own __dict__ so that annotations
    # inherited from a base class are not reported
    return obj.__dict__.get("__annotations__", {})
def struct_instance_repr_recursive(inst: "StructInstance", depth: int) -> str:
    """Pretty-print a StructInstance, recursing into nested struct members."""
    indent = "\t"

    # handle empty structs: just "Name()"
    if not inst._cls.vars:
        return f"{inst._cls.key}()"

    pieces = [f"{inst._cls.key}("]
    for field_name, _ in inst._cls.ctype._fields_:
        field_value = getattr(inst, field_name, None)
        if isinstance(field_value, StructInstance):
            # nested struct: indent one level deeper
            field_value = struct_instance_repr_recursive(field_value, depth + 1)
        pieces.append(f"{indent * (depth + 1)}{field_name}={field_value},")
    pieces.append(f"{indent * depth})")
    return "\n".join(pieces)
class StructInstance:
    """A Python-side instance of a ``@wp.struct`` type.

    Mirrors the struct's fields both as Python attributes and in a flat
    ctypes object (``self._ctype``) that can be handed to native kernels.
    """

    def __init__(self, cls: "Struct", ctype):
        # bypass our own __setattr__ for internal bookkeeping attributes
        super().__setattr__("_cls", cls)

        # maintain a c-types object for the top-level instance the struct
        if not ctype:
            super().__setattr__("_ctype", cls.ctype())
        else:
            super().__setattr__("_ctype", ctype)

        # create Python attributes for each of the struct's variables
        for field, var in cls.vars.items():
            if isinstance(var.type, warp.codegen.Struct):
                self.__dict__[field] = StructInstance(var.type, getattr(self._ctype, field))
            elif isinstance(var.type, warp.types.array):
                self.__dict__[field] = None
            else:
                self.__dict__[field] = var.type()

    def __setattr__(self, name, value):
        """Assign to a struct field, keeping the flat ctypes copy in sync.

        Raises RuntimeError for unknown fields or struct-type mismatches.
        """
        if name not in self._cls.vars:
            raise RuntimeError(f"Trying to set Warp struct attribute that does not exist {name}")

        var = self._cls.vars[name]

        # update our ctype flat copy
        if isinstance(var.type, array):
            if value is None:
                # create array with null pointer
                setattr(self._ctype, name, array_t())
            else:
                # wp.array
                assert isinstance(value, array)
                assert types_equal(
                    value.dtype, var.type.dtype
                ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
                setattr(self._ctype, name, value.__ctype__())

        elif isinstance(var.type, Struct):
            # assign structs by-value, otherwise we would have problematic cases transferring ownership
            # of the underlying ctypes data between shared Python struct instances
            if not isinstance(value, StructInstance):
                raise RuntimeError(
                    f"Trying to assign a non-structure value to a struct attribute with type: {self._cls.key}"
                )

            # destination attribute on self
            dest = getattr(self, name)
            # fix: compare keys by equality; the previous `is not` relied on
            # CPython string interning and could wrongly reject equal keys
            if dest._cls.key != value._cls.key:
                raise RuntimeError(
                    f"Trying to assign a structure of type {value._cls.key} to an attribute of {self._cls.key}"
                )

            # update all nested ctype vars by deep copy
            for n in dest._cls.vars:
                setattr(dest, n, getattr(value, n))

            # early return to avoid updating our Python StructInstance
            return

        elif issubclass(var.type, ctypes.Array):
            # vector/matrix type, e.g. vec3
            if value is None:
                setattr(self._ctype, name, var.type())
            elif types_equal(type(value), var.type):
                setattr(self._ctype, name, value)
            else:
                # conversion from list/tuple, ndarray, etc.
                setattr(self._ctype, name, var.type(value))

        else:
            # primitive type
            if value is None:
                # zero initialize
                setattr(self._ctype, name, var.type._type_())
            else:
                if hasattr(value, "_type_"):
                    # assigning warp type value (e.g.: wp.float32)
                    value = value.value
                # float16 needs conversion to uint16 bits
                if var.type == warp.float16:
                    setattr(self._ctype, name, float_to_half_bits(value))
                else:
                    setattr(self._ctype, name, value)

        # update Python instance
        super().__setattr__(name, value)

    def __ctype__(self):
        # flat ctypes representation used for kernel launches
        return self._ctype

    def __repr__(self):
        return struct_instance_repr_recursive(self, 0)

    # type description used in numpy structured arrays
    def numpy_dtype(self):
        return self._cls.numpy_dtype()

    # value usable in numpy structured arrays of .numpy_dtype(), e.g. (42, 13.37, [1.0, 2.0, 3.0])
    def numpy_value(self):
        npvalue = []
        for name, var in self._cls.vars.items():
            # get the attribute value from the flat ctypes copy
            value = getattr(self._ctype, name)

            if isinstance(var.type, array):
                # array_t
                npvalue.append(value.numpy_value())
            elif isinstance(var.type, Struct):
                # nested struct
                npvalue.append(value.numpy_value())
            elif issubclass(var.type, ctypes.Array):
                if len(var.type._shape_) == 1:
                    # vector
                    npvalue.append(list(value))
                else:
                    # matrix
                    npvalue.append([list(row) for row in value])
            else:
                # scalar
                if var.type == warp.float16:
                    npvalue.append(half_bits_to_float(value))
                else:
                    npvalue.append(value)

        return tuple(npvalue)
class Struct:
    """Code-generation metadata for a user-defined ``@wp.struct`` type.

    Derives per-field ``Var`` entries from the class's type annotations,
    builds a matching ``ctypes.Structure`` layout, and registers default and
    per-field value constructors with the owning module.
    """

    def __init__(self, cls, key, module):
        self.cls = cls
        self.module = module
        self.key = key

        # per-field Var descriptors built from the class's annotations
        self.vars = {}
        annotations = get_annotations(self.cls)
        for label, type in annotations.items():
            self.vars[label] = Var(label, type)

        fields = []
        for label, var in self.vars.items():
            if isinstance(var.type, array):
                # array members are stored as the array_t descriptor
                fields.append((label, array_t))
            elif isinstance(var.type, Struct):
                # nested structs are embedded by value via their ctypes layout
                fields.append((label, var.type.ctype))
            elif issubclass(var.type, ctypes.Array):
                # vector/matrix types are already ctypes arrays
                fields.append((label, var.type))
            else:
                # scalar members use the underlying ctypes type
                fields.append((label, var.type._type_))

        class StructType(ctypes.Structure):
            # if struct is empty, add a dummy field to avoid launch errors on CPU device ("ffi_prep_cif failed")
            _fields_ = fields or [("_dummy_", ctypes.c_byte)]

        self.ctype = StructType

        # create default constructor (zero-initialize)
        self.default_constructor = warp.context.Function(
            func=None,
            key=self.key,
            namespace="",
            value_func=lambda *_: self,
            input_types={},
            initializer_list_func=lambda *_: False,
            native_func=make_full_qualified_name(self.cls),
        )

        # build a constructor that takes each param as a value
        input_types = {label: var.type for label, var in self.vars.items()}

        self.value_constructor = warp.context.Function(
            func=None,
            key=self.key,
            namespace="",
            value_func=lambda *_: self,
            input_types=input_types,
            initializer_list_func=lambda *_: False,
            native_func=make_full_qualified_name(self.cls),
        )

        # the value constructor is registered as an overload of the default one
        self.default_constructor.add_overload(self.value_constructor)

        if module:
            module.register_struct(self)

    def __call__(self):
        """
        This function returns s = StructInstance(self)
        s uses self.cls as template.
        To enable autocomplete on s, we inherit from self.cls.
        For example,

        @wp.struct
        class A:
            # annotations
            ...

        The type annotations are inherited in A(), allowing autocomplete in kernels
        """
        # return StructInstance(self)
        class NewStructInstance(self.cls, StructInstance):
            def __init__(inst):
                StructInstance.__init__(inst, self, None)

        return NewStructInstance()

    def initializer(self):
        # Function object used when the struct type is called inside a kernel
        return self.default_constructor

    # return structured NumPy dtype, including field names, formats, and offsets
    def numpy_dtype(self):
        names = []
        formats = []
        offsets = []
        for name, var in self.vars.items():
            names.append(name)
            # offsets come from the ctypes field descriptors on the layout class
            offsets.append(getattr(self.ctype, name).offset)
            if isinstance(var.type, array):
                # array_t
                formats.append(array_t.numpy_dtype())
            elif isinstance(var.type, Struct):
                # nested struct
                formats.append(var.type.numpy_dtype())
            elif issubclass(var.type, ctypes.Array):
                scalar_typestr = type_typestr(var.type._wp_scalar_type_)
                if len(var.type._shape_) == 1:
                    # vector
                    formats.append(f"{var.type._length_}{scalar_typestr}")
                else:
                    # matrix
                    formats.append(f"{var.type._shape_}{scalar_typestr}")
            else:
                # scalar
                formats.append(type_typestr(var.type))
        return {"names": names, "formats": formats, "offsets": offsets, "itemsize": ctypes.sizeof(self.ctype)}

    # constructs a Warp struct instance from a pointer to the ctype
    def from_ptr(self, ptr):
        if not ptr:
            raise RuntimeError("NULL pointer exception")

        # create a new struct instance
        instance = self()

        for name, var in self.vars.items():
            offset = getattr(self.ctype, name).offset
            if isinstance(var.type, array):
                # We could reconstruct wp.array from array_t, but it's problematic.
                # There's no guarantee that the original wp.array is still allocated and
                # no easy way to make a backref.
                # Instead, we just create a stub annotation, which is not a fully usable array object.
                setattr(instance, name, array(dtype=var.type.dtype, ndim=var.type.ndim))
            elif isinstance(var.type, Struct):
                # nested struct
                value = var.type.from_ptr(ptr + offset)
                setattr(instance, name, value)
            elif issubclass(var.type, ctypes.Array):
                # vector/matrix
                value = var.type.from_ptr(ptr + offset)
                setattr(instance, name, value)
            else:
                # scalar: read through a typed pointer at the field offset
                cvalue = ctypes.cast(ptr + offset, ctypes.POINTER(var.type._type_)).contents
                if var.type == warp.float16:
                    setattr(instance, name, half_bits_to_float(cvalue))
                else:
                    setattr(instance, name, cvalue.value)

        return instance
class Reference:
    """Marker wrapping a value type to denote an l-value reference in codegen."""

    def __init__(self, value_type):
        self.value_type = value_type


def is_reference(type):
    """Return True if `type` is a Reference wrapper."""
    return isinstance(type, Reference)


def strip_reference(arg):
    """Unwrap a Reference to its underlying value type; pass through otherwise."""
    return arg.value_type if is_reference(arg) else arg
def compute_type_str(base_name, template_params):
    """Render a C++ template type string, e.g. ``wp::vec_t<3,wp::float32>``.

    With no template parameters, returns `base_name` unchanged.
    """
    if not template_params:
        return base_name

    def render(param):
        # integral template parameters (e.g. vector length)
        if isinstance(param, int):
            return str(param)
        # scalar types carrying a ctypes _type_ live in the wp:: namespace
        if hasattr(param, "_type_"):
            return f"wp::{param.__name__}"
        return param.__name__

    return f"{base_name}<{','.join(render(p) for p in template_params)}>"
class Var:
    """A named value in generated code, with its Warp type and emit options."""

    def __init__(self, label, type, requires_grad=False, constant=None, prefix=True):
        # convert built-in Python scalar types to Warp types
        if type == float:
            type = float32
        elif type == int:
            type = int32

        self.label = label
        self.type = type
        self.requires_grad = requires_grad
        self.constant = constant
        self.prefix = prefix

    def __str__(self):
        return self.label

    @staticmethod
    def type_to_ctype(t, value_type=False):
        """Return the C++ type name used for `t` in generated code.

        When `t` is a Reference, `value_type=False` yields a pointer type.
        """
        if is_array(t):
            # resolve the element type string first
            if hasattr(t.dtype, "_wp_generic_type_str_"):
                dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
            elif isinstance(t.dtype, Struct):
                dtypestr = make_full_qualified_name(t.dtype.cls)
            elif t.dtype.__name__ in ("bool", "int", "float"):
                dtypestr = t.dtype.__name__
            else:
                dtypestr = f"wp::{t.dtype.__name__}"
            return f"wp::{type(t).__name__}_t<{dtypestr}>"
        if isinstance(t, Struct):
            return make_full_qualified_name(t.cls)
        if is_reference(t):
            # references lower to raw pointers unless the value type is requested
            return Var.type_to_ctype(t.value_type) + ("" if value_type else "*")
        if hasattr(t, "_wp_generic_type_str_"):
            return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
        if t.__name__ in ("bool", "int", "float"):
            return t.__name__
        return f"wp::{t.__name__}"

    def ctype(self, value_type=False):
        return Var.type_to_ctype(self.type, value_type)

    def emit(self, prefix: str = "var"):
        # mangled name used in generated code, e.g. "var_x";
        # raw label when prefixing is disabled
        if self.prefix:
            return f"{prefix}_{self.label}"
        return self.label

    def emit_adj(self):
        return self.emit("adj")
class Block:
    # Represents a basic block of instructions, e.g.: list
    # of straight line instructions inside a for-loop or conditional

    def __init__(self):
        # statement lists for the forward pass, tape replay, and reverse pass
        self.body_forward = []
        self.body_replay = []
        self.body_reverse = []

        # list of vars declared in this block
        self.vars = []
class Adjoint:
# Source code transformer, this class takes a Python function and
# generates forward and backward SSA forms of the function instructions
def __init__(
adj,
func,
overload_annotations=None,
is_user_function=False,
skip_forward_codegen=False,
skip_reverse_codegen=False,
custom_reverse_mode=False,
custom_reverse_num_input_args=-1,
transformers: List[ast.NodeTransformer] = [],
):
adj.func = func
adj.is_user_function = is_user_function
# whether the generation of the forward code is skipped for this function
adj.skip_forward_codegen = skip_forward_codegen
# whether the generation of the adjoint code is skipped for this function
adj.skip_reverse_codegen = skip_reverse_codegen
# extract name of source file
adj.filename = inspect.getsourcefile(func) or "unknown source file"
# get source file line number where function starts
_, adj.fun_lineno = inspect.getsourcelines(func)
# get function source code
adj.source = inspect.getsource(func)
# ensures that indented class methods can be parsed as kernels
adj.source = textwrap.dedent(adj.source)
adj.source_lines = adj.source.splitlines()
# build AST and apply node transformers
adj.tree = ast.parse(adj.source)
adj.transformers = transformers
for transformer in transformers:
adj.tree = transformer.visit(adj.tree)
adj.fun_name = adj.tree.body[0].name
# for keeping track of line number in function code
adj.lineno = None
# whether the forward code shall be used for the reverse pass and a custom
# function signature is applied to the reverse version of the function
adj.custom_reverse_mode = custom_reverse_mode
# the number of function arguments that pertain to the forward function
# input arguments (i.e. the number of arguments that are not adjoint arguments)
adj.custom_reverse_num_input_args = custom_reverse_num_input_args
# parse argument types
argspec = inspect.getfullargspec(func)
# ensure all arguments are annotated
if overload_annotations is None:
# use source-level argument annotations
if len(argspec.annotations) < len(argspec.args):
raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
adj.arg_types = argspec.annotations
else:
# use overload argument annotations
for arg_name in argspec.args:
if arg_name not in overload_annotations:
raise WarpCodegenError(f"Incomplete overload annotations for function {adj.fun_name}")
adj.arg_types = overload_annotations.copy()
adj.args = []
adj.symbols = {}
for name, type in adj.arg_types.items():
# skip return hint
if name == "return":
continue
# add variable for argument
arg = Var(name, type, False)
adj.args.append(arg)
# pre-populate symbol dictionary with function argument names
# this is to avoid registering false references to overshadowed modules
adj.symbols[name] = arg
# There are cases where a same module might be rebuilt multiple times,
# for example when kernels are nested inside of functions, or when
# a kernel's launch raises an exception. Ideally we'd always want to
# avoid rebuilding kernels but some corner cases seem to depend on it,
# so we only avoid rebuilding kernels that errored out to give a chance
# for unit testing errors being spit out from kernels.
adj.skip_build = False
# generate function ssa form and adjoint
def build(adj, builder):
    # Evaluates the function body to produce SSA forward/replay/reverse blocks.
    # `builder` collects structs/functions that need native compilation and
    # may be None when only parsing is required.
    if adj.skip_build:
        return

    adj.builder = builder

    adj.symbols = {}  # map from symbols to adjoint variables
    adj.variables = []  # list of local variables (in order)
    adj.return_var = None  # return type for function or kernel
    adj.loop_symbols = []  # symbols at the start of each loop

    # blocks
    adj.blocks = [Block()]
    adj.loop_blocks = []

    # holds current indent level
    adj.indentation = ""

    # used to generate new label indices
    adj.label_count = 0

    # update symbol map for each argument
    for a in adj.args:
        adj.symbols[a.label] = a

    # recursively evaluate function body
    try:
        adj.eval(adj.tree.body[0])
    except Exception as e:
        try:
            # rewrap the exception with source-location context where possible
            if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
                # unsupported AST constructs surface as KeyError on the ast node type
                msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
            else:
                msg = "Error"

            lineno = adj.lineno + adj.fun_lineno
            line = adj.source_lines[adj.lineno]

            msg += f' while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'

            ex, data, traceback = sys.exc_info()
            e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
        finally:
            # mark as failed so subsequent builds don't retry this function
            adj.skip_build = True
            raise e

    if builder is not None:
        # ensure struct argument types (and array-of-struct dtypes) get built natively
        for a in adj.args:
            if isinstance(a.type, Struct):
                builder.build_struct_recursive(a.type)
            elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
                builder.build_struct_recursive(a.type.dtype)
# code generation methods
def format_template(adj, template, input_vars, output_var):
    # placeholder {0} is always the output variable; inputs follow as {1}, {2}, ...
    return template.format(*([output_var] + input_vars))
# generates a list of formatted args
def format_args(adj, prefix, args):
    # Maps each arg (Var or warp Function) to its name in generated code,
    # applying `prefix` ("var", "adj", or "&adj" for initializer lists).
    arg_strs = []

    for a in args:
        if isinstance(a, warp.context.Function):
            # functions don't have a var_ prefix so strip it off here
            if prefix == "var":
                arg_strs.append(a.key)
            else:
                arg_strs.append(f"{prefix}_{a.key}")
        elif is_reference(a.type):
            # references are emitted from the raw label, bypassing Var.emit()
            arg_strs.append(f"{prefix}_{a}")
        elif isinstance(a, Var):
            arg_strs.append(a.emit(prefix))
        else:
            raise WarpCodegenTypeError(f"Arguments must be variables or functions, got {type(a)}")

    return arg_strs
# generates argument string for a forward function call
def format_forward_call_args(adj, args, use_initializer_list):
arg_str = ", ".join(adj.format_args("var", args))
if use_initializer_list:
return f"{{{arg_str}}}"
return arg_str
# generates argument string for a reverse function call
def format_reverse_call_args(
    adj,
    args_var,
    args,
    args_out,
    use_initializer_list,
    has_output_args=True,
    require_original_output_arg=False,
):
    # Returns the argument string for the adjoint call, or None when the
    # function has no adjoint arguments at all (so the call can be skipped).
    formatted_var = adj.format_args("var", args_var)
    formatted_out = []
    # original outputs are only passed when required or for multi-output functions
    if has_output_args and (require_original_output_arg or len(args_out) > 1):
        formatted_out = adj.format_args("var", args_out)
    # initializer lists take the address of the adjoint inputs
    formatted_var_adj = adj.format_args(
        "&adj" if use_initializer_list else "adj",
        args,
    )
    formatted_out_adj = adj.format_args("adj", args_out)

    if len(formatted_var_adj) == 0 and len(formatted_out_adj) == 0:
        # there are no adjoint arguments, so we don't need to call the reverse function
        return None

    if use_initializer_list:
        var_str = f"{{{', '.join(formatted_var)}}}"
        out_str = f"{{{', '.join(formatted_out)}}}"
        adj_str = f"{{{', '.join(formatted_var_adj)}}}"
        out_adj_str = ", ".join(formatted_out_adj)
        if len(args_out) > 1:
            arg_str = ", ".join([var_str, out_str, adj_str, out_adj_str])
        else:
            arg_str = ", ".join([var_str, adj_str, out_adj_str])
    else:
        arg_str = ", ".join(formatted_var + formatted_out + formatted_var_adj + formatted_out_adj)
    return arg_str
def indent(adj):
    # one indentation level is four spaces
    adj.indentation += "    "
def dedent(adj):
    # drop the trailing four-space indentation level
    adj.indentation = adj.indentation[:-4]
def begin_block(adj):
    # open a new basic block and give it a unique label id
    block = Block()
    block.label = adj.label_count
    adj.label_count += 1
    adj.blocks.append(block)
    return block
def end_block(adj):
    # close and return the innermost open block
    return adj.blocks.pop()
def add_var(adj, type=None, constant=None):
    # SSA variables are named by their ordinal index
    name = str(len(adj.variables))

    # allocate the variable and register it function-wide and in the current block
    v = Var(name, type=type, constant=constant)
    adj.variables.append(v)
    adj.blocks[-1].vars.append(v)

    return v
# append a statement to the forward pass
def add_forward(adj, statement, replay=None, skip_replay=False):
    block = adj.blocks[-1]
    block.body_forward.append(adj.indentation + statement)

    if skip_replay:
        return

    # the replay body defaults to the original statement unless a custom
    # replay statement is supplied
    block.body_replay.append(adj.indentation + (replay if replay else statement))
# append a statement to the reverse pass
def add_reverse(adj, statement):
    current = adj.blocks[-1]
    current.body_reverse.append(adj.indentation + statement)
def add_constant(adj, n):
    # constants become vars typed by the literal's Python type
    return adj.add_var(type=type(n), constant=n)
def load(adj, var):
    # dereference l-value references so subsequent reads see a plain value
    if is_reference(var.type):
        return adj.add_builtin_call("load", [var])
    return var
def add_comp(adj, op_strings, left, comps):
    # Emits a (possibly chained) comparison, e.g. a < b < c, and returns the
    # boolean output var. Chains are rewritten as "(a < b) && (b < c)" and
    # restricted to operands of identical type.
    output = adj.add_var(builtins.bool)

    left = adj.load(left)
    # open one paren per comparator; each term closes one
    s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "

    prev_comp = None

    for op, comp in zip(op_strings, comps):
        comp_chainable = op_str_is_chainable(op)
        if comp_chainable and prev_comp:
            # We restrict chaining to operands of the same type
            if prev_comp.type is comp.type:
                prev_comp = adj.load(prev_comp)
                comp = adj.load(comp)
                s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
            else:
                raise WarpCodegenTypeError(
                    f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
                )
        else:
            comp = adj.load(comp)
            s += op + " " + comp.emit() + ") "

        # remember the right-hand operand for the next link in the chain
        prev_comp = comp

    s = s.rstrip() + ";"

    adj.add_forward(s)

    return output
def add_bool_op(adj, op_string, exprs):
    # materialize all operands (loading references), then emit e.g. "a && b && c"
    loaded = [adj.load(expr) for expr in exprs]
    output = adj.add_var(builtins.bool)

    joined = (" " + op_string + " ").join(term.emit() for term in loaded)
    adj.add_forward(output.emit() + " = " + joined + ";")

    return output
def resolve_func(adj, func, args, min_outputs, templates, kwds):
    # Selects the concrete overload of `func` matching the given argument
    # types (and, when requested, the expected output count). Raises
    # WarpCodegenError when no overload is compatible.
    arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]

    if not func.is_builtin():
        # user-defined function
        overload = func.get_overload(arg_types)
        if overload is not None:
            return overload
    else:
        # if func is overloaded then perform overload resolution here
        # we validate argument types before they go to generated native code
        for f in func.overloads:
            # skip type checking for variadic functions
            if not f.variadic:
                # check argument counts match are compatible (may be some default args)
                if len(f.input_types) < len(args):
                    continue

                def match_args(args, f):
                    # check argument types equal
                    for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
                        # if arg type registered as Any, treat as
                        # template allowing any type to match
                        if arg_type == Any:
                            continue

                        # handle function refs as a special case
                        if arg_type == Callable and type(args[i]) is warp.context.Function:
                            continue

                        if arg_type == Reference and is_reference(args[i].type):
                            continue

                        # look for default values for missing args
                        if i >= len(args):
                            if arg_name not in f.defaults:
                                return False
                        else:
                            # otherwise check arg type matches input variable type
                            if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
                                return False

                    return True

                if not match_args(args, f):
                    continue

            # check output dimensions match expectations
            if min_outputs:
                try:
                    value_type = f.value_func(args, kwds, templates)
                    if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
                        continue
                except Exception:
                    # value func may fail if the user has given
                    # incorrect args, so we need to catch this
                    continue

            # found a match, use it
            return f

    # unresolved function, report error
    arg_types = []

    for x in args:
        if isinstance(x, Var):
            # shorten Warp primitive type names
            if isinstance(x.type, list):
                # multi-valued results cannot be passed as arguments directly
                if len(x.type) != 1:
                    raise WarpCodegenError("Argument must not be the result from a multi-valued function")
                arg_type = x.type[0]
            else:
                arg_type = x.type

            arg_types.append(type_repr(arg_type))

        if isinstance(x, warp.context.Function):
            arg_types.append("function")

    raise WarpCodegenError(
        f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_types)}]"
    )
def add_call(adj, func, args, min_outputs=None, templates=None, kwds=None):
    """Resolve an overload of `func` for `args` and emit its forward,
    replay, and reverse (adjoint) call statements.

    Returns the output Var, a list of Vars for multi-valued functions,
    or None for void functions.
    """
    # 'templates' previously defaulted to a shared mutable list ([]);
    # keep the call-compatible interface but allocate a fresh list per call
    if templates is None:
        templates = []

    func = adj.resolve_func(func, args, min_outputs, templates, kwds)

    # push any default values onto args
    for i, (arg_name, arg_type) in enumerate(func.input_types.items()):
        if i >= len(args):
            if arg_name in func.defaults:
                const = adj.add_constant(func.defaults[arg_name])
                args.append(const)
            else:
                break

    # if it is a user-function then build it recursively
    if not func.is_builtin():
        adj.builder.build_function(func)

    # evaluate the function type based on inputs
    arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
    return_type = func.value_func(arg_types, kwds, templates)

    func_name = compute_type_str(func.native_func, templates)
    param_types = list(func.input_types.values())

    use_initializer_list = func.initializer_list_func(args, templates)

    # load reference args, except when the parameter expects a Reference or Callable
    args_var = [
        adj.load(a)
        if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
        else a
        for i, a in enumerate(args)
    ]

    if return_type is None:
        # handles expression (zero output) functions, e.g.: void do_something();
        output = None
        output_list = []

        forward_call = (
            f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
        )
        replay_call = forward_call
        if func.custom_replay_func is not None:
            replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"

    elif not isinstance(return_type, list) or len(return_type) == 1:
        # handle simple function (one output)
        if isinstance(return_type, list):
            return_type = return_type[0]
        output = adj.add_var(return_type)
        output_list = [output]

        forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
        replay_call = forward_call
        if func.custom_replay_func is not None:
            replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"

    else:
        # handle multiple value functions
        output = [adj.add_var(v) for v in return_type]
        output_list = output

        forward_call = (
            f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var + output, use_initializer_list)});"
        )
        replay_call = forward_call

    if func.skip_replay:
        # comment out the replay statement so the tape skips this call
        adj.add_forward(forward_call, replay="// " + replay_call)
    else:
        adj.add_forward(forward_call, replay=replay_call)

    if not func.missing_grad and len(args):
        # original outputs are passed to the adjoint only when required or multi-output
        reverse_has_output_args = (
            func.require_original_output_arg or len(output_list) > 1
        ) and func.custom_grad_func is None
        arg_str = adj.format_reverse_call_args(
            args_var,
            args,
            output_list,
            use_initializer_list,
            has_output_args=reverse_has_output_args,
            require_original_output_arg=func.require_original_output_arg,
        )
        if arg_str is not None:
            reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
            adj.add_reverse(reverse_call)

    return output
def add_builtin_call(adj, func_name, args, min_outputs=None, templates=None, kwds=None):
    """Look up a Warp builtin by name and emit a call to it via add_call()."""
    # avoid the shared-mutable-default anti-pattern for templates
    if templates is None:
        templates = []
    func = warp.context.builtin_functions[func_name]
    return adj.add_call(func, args, min_outputs, templates, kwds)
def add_return(adj, var):
if var is None or len(var) == 0:
adj.add_forward("return;", f"goto label{adj.label_count};")
elif len(var) == 1:
adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
adj.add_reverse("adj_" + str(var[0]) + " += adj_ret;")
else:
for i, v in enumerate(var):
adj.add_forward(f"ret_{i} = {v.emit()};")
adj.add_reverse(f"adj_{v} += adj_ret_{i};")
adj.add_forward("return;", f"goto label{adj.label_count};")
adj.add_reverse(f"label{adj.label_count}:;")
adj.label_count += 1
# define an if statement
def begin_if(adj, cond):
    # forward: open the guarded scope; reverse: close it (emitted backwards)
    test = adj.load(cond)
    adj.add_forward(f"if ({test.emit()}) {{")
    adj.add_reverse("}")

    adj.indent()
def end_if(adj, cond):
    adj.dedent()

    # forward: close the scope; reverse: re-open the guarded branch
    adj.add_forward("}")
    test = adj.load(cond)
    adj.add_reverse(f"if ({test.emit()}) {{")
def begin_else(adj, cond):
    # the else branch is emitted as an inverted condition: if (!cond)
    test = adj.load(cond)
    adj.add_forward(f"if (!{test.emit()}) {{")
    adj.add_reverse("}")

    adj.indent()
def end_else(adj, cond):
    adj.dedent()

    # forward: close the else scope; reverse: re-open the inverted branch
    adj.add_forward("}")
    test = adj.load(cond)
    adj.add_reverse(f"if (!{test.emit()}) {{")
# define a for-loop
def begin_for(adj, iter):
    # the condition check lives in its own block so the reverse pass can replay it
    cond_block = adj.begin_block()
    adj.loop_blocks.append(cond_block)
    label = cond_block.label

    adj.add_forward(f"for_start_{label}:;")
    adj.indent()

    # evaluate cond: terminate when the iterator is exhausted
    adj.add_forward(f"if (iter_cmp({iter.emit()}) == 0) goto for_end_{label};")

    # evaluate iter: advance and open the loop-body block
    val = adj.add_builtin_call("iter_next", [iter])
    adj.begin_block()

    return val
def end_for(adj, iter):
    # Closes a for-loop: splices the condition and body blocks into the
    # enclosing block for the forward pass, then assembles the reverse pass
    # (reversed iterator, adjoint zeroing, replay, reversed body).
    body_block = adj.end_block()
    cond_block = adj.end_block()
    adj.loop_blocks.pop()

    ####################
    # forward pass

    for i in cond_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    for i in body_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    adj.add_forward(f"goto for_start_{cond_block.label};", skip_replay=True)

    adj.dedent()
    adj.add_forward(f"for_end_{cond_block.label}:;", skip_replay=True)

    ####################
    # reverse pass

    reverse = []

    # reverse iterator
    reverse.append(adj.indentation + f"{iter.emit()} = wp::iter_reverse({iter.emit()});")

    for i in cond_block.body_forward:
        reverse.append(i)

    # zero adjoints
    for i in body_block.vars:
        reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")

    # replay
    for i in body_block.body_replay:
        reverse.append(i)

    # reverse
    for i in reversed(body_block.body_reverse):
        reverse.append(i)

    reverse.append(adj.indentation + f"\tgoto for_start_{cond_block.label};")
    reverse.append(adj.indentation + f"for_end_{cond_block.label}:;")

    # appended reversed because the entire reverse pass is emitted backwards
    adj.blocks[-1].body_reverse.extend(reversed(reverse))
# define a while loop
def begin_while(adj, cond):
    # evaluate condition in its own block
    # so we can control replay
    cond_block = adj.begin_block()
    adj.loop_blocks.append(cond_block)
    label = cond_block.label

    cond_block.body_forward.append(f"while_start_{label}:;")
    c = adj.eval(cond)
    cond_block.body_forward.append(f"if (({c.emit()}) == false) goto while_end_{label};")

    # begin block around loop body
    adj.begin_block()
    adj.indent()
def end_while(adj):
    # Closes a while-loop: splices the condition and body blocks into the
    # enclosing block for the forward pass, then assembles the reverse pass
    # (condition re-check, adjoint zeroing, replay, reversed body).
    adj.dedent()
    body_block = adj.end_block()
    cond_block = adj.end_block()
    adj.loop_blocks.pop()

    ####################
    # forward pass

    for i in cond_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    for i in body_block.body_forward:
        adj.blocks[-1].body_forward.append(i)

    adj.blocks[-1].body_forward.append(f"goto while_start_{cond_block.label};")
    adj.blocks[-1].body_forward.append(f"while_end_{cond_block.label}:;")

    ####################
    # reverse pass

    reverse = []

    # cond
    for i in cond_block.body_forward:
        reverse.append(i)

    # zero adjoints of local vars
    for i in body_block.vars:
        reverse.append(f"{i.emit_adj()} = {{}};")

    # replay
    for i in body_block.body_replay:
        reverse.append(i)

    # reverse
    for i in reversed(body_block.body_reverse):
        reverse.append(i)

    reverse.append(f"goto while_start_{cond_block.label};")
    reverse.append(f"while_end_{cond_block.label}:;")

    # output (appended reversed: the entire reverse pass is emitted backwards)
    adj.blocks[-1].body_reverse.extend(reversed(reverse))
def emit_FunctionDef(adj, node):
    # evaluate every statement of the function body in order
    for statement in node.body:
        adj.eval(statement)

    # single-output functions that fall off the end get an implicit zero return
    if adj.return_var is not None and len(adj.return_var) == 1:
        if not isinstance(node.body[-1], ast.Return):
            adj.add_forward("return {};", skip_replay=True)
def emit_If(adj, node):
    """Emit an `if` statement, resolving symbols redefined inside either
    branch with a select() phi function so downstream code sees a single
    SSA value per symbol."""
    if len(node.body) == 0:
        return None

    # eval condition
    cond = adj.eval(node.test)

    # save symbol map so branch-local redefinitions can be detected
    symbols_prev = adj.symbols.copy()

    # eval body
    adj.begin_if(cond)

    for stmt in node.body:
        adj.eval(stmt)

    adj.end_if(cond)

    # detect existing symbols with conflicting definitions (variables assigned inside the branch)
    # and resolve with a phi (select) function
    for sym, var1 in symbols_prev.items():
        var2 = adj.symbols[sym]
        if var1 != var2:
            # insert a phi function that selects var1, var2 based on cond
            out = adj.add_builtin_call("select", [cond, var1, var2])
            adj.symbols[sym] = out

    symbols_prev = adj.symbols.copy()

    # evaluate 'else' statement as if (!cond)
    if len(node.orelse) > 0:
        adj.begin_else(cond)

        for stmt in node.orelse:
            adj.eval(stmt)

        adj.end_else(cond)

    # detect existing symbols with conflicting definitions (variables assigned inside the else)
    # and resolve with a phi (select) function
    for sym, var1 in symbols_prev.items():
        var2 = adj.symbols[sym]
        if var1 != var2:
            # insert a phi function that selects var1, var2 based on cond
            # note the reversed order of vars since we want to use !cond as our select
            out = adj.add_builtin_call("select", [cond, var2, var1])
            adj.symbols[sym] = out
def emit_Compare(adj, node):
    """Emit a (possibly chained) comparison, e.g. (a < b) <= c."""
    # node.left is compared against each entry of node.comparators in turn
    lhs = adj.eval(node.left)
    rhs_list = [adj.eval(comparator) for comparator in node.comparators]
    op_strings = [builtin_operators[type(op)] for op in node.ops]
    return adj.add_comp(op_strings, lhs, rhs_list)
def emit_BoolOp(adj, node):
    """Emit a boolean and/or over node.values."""
    op = node.op
    if isinstance(op, ast.And):
        func = "&&"
    elif isinstance(op, ast.Or):
        func = "||"
    else:
        raise WarpCodegenKeyError(f"Op {op} is not supported")

    operands = [adj.eval(expr) for expr in node.values]
    return adj.add_bool_op(func, operands)
def emit_Name(adj, node):
    """Resolve a name to an existing symbol, a captured/global constant,
    or a raw Python object (function, class, or module)."""
    # lookup symbol, if it has already been assigned to a variable then return the existing mapping
    if node.id in adj.symbols:
        return adj.symbols[node.id]

    # try and resolve the name using the function's globals context (used to lookup constants + functions)
    obj = adj.func.__globals__.get(node.id)
    if obj is None:
        # fall back to constants captured in the function's closure cells
        cells = adj.func.__closure__ or []
        captured = dict(zip(adj.func.__code__.co_freevars, (c.cell_contents for c in cells)))
        obj = captured.get(str(node.id), None)

    if obj is None:
        raise WarpCodegenKeyError("Referencing undefined symbol: " + str(node.id))

    if warp.types.is_value(obj):
        # fold into a constant var and cache it under the symbol name
        const_var = adj.add_constant(obj)
        adj.symbols[node.id] = const_var
        return const_var

    # the named object is either a function, class name, or module
    # pass it back to the caller for processing
    return obj
@staticmethod
def resolve_type_attribute(var_type: type, attr: str):
    """Resolve a predefined attribute (e.g. `dtype`, `length`) on a Warp
    value type; fall back to plain getattr(), returning None if absent."""
    if isinstance(var_type, type) and type_is_value(var_type):
        if attr == "dtype":
            return type_scalar_type(var_type)
        if attr == "length":
            return type_length(var_type)
    return getattr(var_type, attr, None)
def vector_component_index(adj, component, vector_type):
    """Map a swizzle character ('x'..'w', bounded by the vector's length)
    to a constant index Var."""
    if len(component) != 1:
        raise WarpCodegenAttributeError(f"Vector swizzle must be single character, got .{component}")

    dim = vector_type._shape_[0]
    swizzles = "xyzw"[0:dim]
    if component not in swizzles:
        raise WarpCodegenAttributeError(
            f"Vector swizzle for {vector_type} must be one of {swizzles}, got {component}"
        )

    return adj.add_constant(swizzles.index(component))
@staticmethod
def is_differentiable_value_type(var_type):
    """Return True for value types (i.e. not arrays) that may hold
    differentiable data requiring gradient accumulation: float-based
    scalars/vectors/matrices, or structs."""
    if isinstance(var_type, Struct):
        return True
    return type_scalar_type(var_type) in float_types
def emit_Attribute(adj, node):
    """Emit an attribute access: module/class attributes resolve to Python
    objects (or constants), vector components become extract() calls, and
    struct fields become references with adjoint accumulation."""
    if hasattr(node, "is_adjoint"):
        # propagate adjoint marking (wp.adjoint[...]) down to the value
        node.value.is_adjoint = True

    aggregate = adj.eval(node.value)

    try:
        if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
            out = getattr(aggregate, node.attr)

            if warp.types.is_value(out):
                return adj.add_constant(out)

            return out

        if hasattr(node, "is_adjoint"):
            # create a Var that points to the struct attribute, i.e.: directly generates `struct.attr` when used
            attr_name = aggregate.label + "." + node.attr
            attr_type = aggregate.type.vars[node.attr].type
            return Var(attr_name, attr_type)

        aggregate_type = strip_reference(aggregate.type)

        # reading a vector component
        if type_is_vector(aggregate_type):
            index = adj.vector_component_index(node.attr, aggregate_type)
            return adj.add_builtin_call("extract", [aggregate, index])
        else:
            # struct member access: take the address of the field
            attr_type = Reference(aggregate_type.vars[node.attr].type)
            attr = adj.add_var(attr_type)

            if is_reference(aggregate.type):
                adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
            else:
                adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")

            if adj.is_differentiable_value_type(strip_reference(attr_type)):
                # accumulate adjoints for differentiable fields
                adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
            else:
                adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")

            return attr
    except (KeyError, AttributeError):
        # Try resolving as type attribute
        aggregate_type = strip_reference(aggregate.type) if isinstance(aggregate, Var) else aggregate
        type_attribute = adj.resolve_type_attribute(aggregate_type, node.attr)
        if type_attribute is not None:
            return type_attribute

        if isinstance(aggregate, Var):
            raise WarpCodegenAttributeError(
                f"Error, `{node.attr}` is not an attribute of '{node.value.id}' ({type_repr(aggregate.type)})"
            )
        raise WarpCodegenAttributeError(f"Error, `{node.attr}` is not an attribute of '{aggregate}'")
def emit_String(adj, node):
    """Emit a string literal as a constant var."""
    return adj.add_constant(node.s)
def emit_Num(adj, node):
    """Emit a numeric literal, caching constants per (value, type) so that
    e.g. 1 and 1.0 remain distinct."""
    key = (node.n, type(node.n))
    if key not in adj.symbols:
        adj.symbols[key] = adj.add_constant(node.n)
    return adj.symbols[key]
def emit_Ellipsis(adj, node):
    """An Ellipsis marks a stubbed @wp.native_func body; emit nothing."""
    return None
def emit_NameConstant(adj, node):
    """Emit True/False as boolean constants; None is not supported."""
    if node.value is None:
        raise WarpCodegenTypeError("None type unsupported")
    return adj.add_constant(bool(node.value))
def emit_Constant(adj, node):
    """Dispatch a constant node to the matching specialized emitter."""
    if isinstance(node, ast.Str):
        return adj.emit_String(node)
    if isinstance(node, ast.Num):
        return adj.emit_Num(node)
    if isinstance(node, ast.Ellipsis):
        return adj.emit_Ellipsis(node)
    # the only remaining constant kind is True/False/None
    assert isinstance(node, ast.NameConstant)
    return adj.emit_NameConstant(node)
def emit_BinOp(adj, node):
    """Emit a binary operator as the corresponding builtin call."""
    lhs = adj.eval(node.left)
    rhs = adj.eval(node.right)
    return adj.add_builtin_call(builtin_operators[type(node.op)], [lhs, rhs])
def emit_UnaryOp(adj, node):
    """Emit a unary operator as the corresponding builtin call."""
    operand = adj.eval(node.operand)
    return adj.add_builtin_call(builtin_operators[type(node.op)], [operand])
def materialize_redefinitions(adj, symbols):
    """Detect symbols reassigned inside a dynamic loop (by comparing the
    pre-loop snapshot `symbols` against the current symbol map) and write
    the new value back into the original variable, restoring the symbol
    mapping. Raises WarpCodegenError when a constant was mutated."""
    for sym, var1 in symbols.items():
        var2 = adj.symbols[sym]
        if var1 != var2:
            if warp.config.verbose and not adj.custom_reverse_mode:
                lineno = adj.lineno + adj.fun_lineno
                line = adj.source_lines[adj.lineno]
                msg = f'Warning: detected mutated variable {sym} during a dynamic for-loop in function "{adj.fun_name}" at {adj.filename}:{lineno}: this may not be a differentiable operation.\n{line}\n'
                print(msg)

            if var1.constant is not None:
                raise WarpCodegenError(
                    f"Error mutating a constant {sym} inside a dynamic loop, use the following syntax: pi = float(3.141) to declare a dynamic variable"
                )

            # overwrite the old variable value (violates SSA)
            adj.add_builtin_call("assign", [var1, var2])

            # reset the symbol to point to the original variable
            adj.symbols[sym] = var1
def emit_While(adj, node):
    """Emit a dynamic while loop, materializing any symbols mutated in
    the body back into their pre-loop variables."""
    adj.begin_while(node.test)

    # snapshot symbols so body mutations can be detected afterwards
    adj.loop_symbols.append(adj.symbols.copy())

    for stmt in node.body:
        adj.eval(stmt)

    adj.materialize_redefinitions(adj.loop_symbols[-1])
    adj.loop_symbols.pop()

    adj.end_while()
def eval_num(adj, a):
    """Try to evaluate an expression as a compile-time numeric constant.

    Returns a (flag, value) pair: True with the literal value for plain
    (possibly negated) numeric literals; otherwise the expression is
    resolved statically and the flag reports whether the resolved object
    is an integer (per warp.types.is_int).
    """
    if isinstance(a, ast.Num):
        return True, a.n
    if isinstance(a, ast.UnaryOp) and isinstance(a.op, ast.USub) and isinstance(a.operand, ast.Num):
        # negated literal, e.g. -3
        return True, -a.operand.n

    # try and resolve the expression to an object
    # e.g.: wp.constant in the globals scope
    obj, _ = adj.resolve_static_expression(a)

    if isinstance(obj, Var) and obj.constant is not None:
        # unwrap a constant Var to its underlying Python value
        obj = obj.constant

    return warp.types.is_int(obj), obj
# detects whether a loop contains a break (or continue) statement
def contains_break(adj, body):
    """Return True if `body` contains a break/continue that would affect
    the current loop (nested for/while loops are intentionally skipped,
    since their break/continue do not escape them)."""
    for stmt in body:
        if isinstance(stmt, (ast.Break, ast.Continue)):
            return True
        if isinstance(stmt, ast.If):
            if adj.contains_break(stmt.body) or adj.contains_break(stmt.orelse):
                return True
        # nested for or while loops containing a break statement
        # do not affect the current loop
    return False
# returns a constant range() if unrollable, otherwise None
def get_unroll_range(adj, loop):
    """Decide how a for-loop iterator is handled.

    Returns a Python range() when the loop is a range() call with constant
    numeric arguments small enough to unroll; returns the evaluated dynamic
    range var when constants are present but unrolling is not possible; and
    returns None when the iterator is not a plain 1-3 argument range() call.
    """
    if (
        not isinstance(loop.iter, ast.Call)
        or not isinstance(loop.iter.func, ast.Name)
        or loop.iter.func.id != "range"
        or len(loop.iter.args) == 0
        or len(loop.iter.args) > 3
    ):
        return None

    # if all range() arguments are numeric constants we will unroll
    # note that this only handles trivial constants, it will not unroll
    # constant compile-time expressions e.g.: range(0, 3*2)

    # Evaluate the arguments and check that they are numeric constants
    # It is important to do that in one pass, so that if evaluating these arguments have side effects
    # the code does not get generated more than once
    range_args = [adj.eval_num(arg) for arg in loop.iter.args]
    arg_is_numeric, arg_values = zip(*range_args)

    if all(arg_is_numeric):
        # All argument are numeric constants

        # range(end)
        if len(loop.iter.args) == 1:
            start = 0
            end = arg_values[0]
            step = 1

        # range(start, end)
        elif len(loop.iter.args) == 2:
            start = arg_values[0]
            end = arg_values[1]
            step = 1

        # range(start, end, step)
        elif len(loop.iter.args) == 3:
            start = arg_values[0]
            end = arg_values[1]
            step = arg_values[2]

        # test if we're above max unroll count
        # NOTE(review): a constant step of 0 raises ZeroDivisionError here
        # (range() itself would reject it anyway) — confirm whether a
        # friendlier diagnostic is wanted
        max_iters = abs(end - start) // abs(step)
        max_unroll = adj.builder.options["max_unroll"]

        ok_to_unroll = True

        if max_iters > max_unroll:
            if warp.config.verbose:
                print(
                    f"Warning: fixed-size loop count of {max_iters} is larger than the module 'max_unroll' limit of {max_unroll}, will generate dynamic loop."
                )
            ok_to_unroll = False

        elif adj.contains_break(loop.body):
            if warp.config.verbose:
                print("Warning: 'break' or 'continue' found in loop body, will generate dynamic loop.")
            ok_to_unroll = False

        if ok_to_unroll:
            return range(start, end, step)

    # Unroll is not possible, range needs to be valuated dynamically
    range_call = adj.add_builtin_call(
        "range",
        [adj.add_constant(val) if is_numeric else val for is_numeric, val in range_args],
    )
    return range_call
def emit_For(adj, node):
    """Emit a for loop, unrolling constant range() iterations where
    possible, otherwise generating a dynamic loop."""
    # try and unroll simple range() statements that use constant args
    unroll_range = adj.get_unroll_range(node)

    if isinstance(unroll_range, range):
        # static unroll: emit the body once per iteration with a constant index
        for i in unroll_range:
            const_iter = adj.add_constant(i)
            var_iter = adj.add_builtin_call("int", [const_iter])
            adj.symbols[node.target.id] = var_iter

            for stmt in node.body:
                adj.eval(stmt)
        return

    # dynamic loop: reuse the range evaluated during the unroll attempt, if any
    if unroll_range is not None:
        iterable = unroll_range
    else:
        iterable = adj.eval(node.iter)

    adj.symbols[node.target.id] = adj.begin_for(iterable)

    # for loops should be side-effect free, here we store a copy
    adj.loop_symbols.append(adj.symbols.copy())

    for stmt in node.body:
        adj.eval(stmt)

    adj.materialize_redefinitions(adj.loop_symbols[-1])
    adj.loop_symbols.pop()

    adj.end_for(iterable)
def emit_Break(adj, node):
    """Emit a break: flush mutated symbols, then jump past the loop end."""
    adj.materialize_redefinitions(adj.loop_symbols[-1])
    adj.add_forward(f"goto for_end_{adj.loop_blocks[-1].label};")
def emit_Continue(adj, node):
    """Emit a continue: flush mutated symbols, then jump back to the loop start."""
    adj.materialize_redefinitions(adj.loop_symbols[-1])
    adj.add_forward(f"goto for_start_{adj.loop_blocks[-1].label};")
def emit_Expr(adj, node):
    """Evaluate a bare expression statement and return its value."""
    return adj.eval(node.value)
def check_tid_in_func_error(adj, node):
    """Raise a descriptive error when tid() is called from a user
    function rather than a kernel."""
    if not adj.is_user_function:
        return
    if hasattr(node.func, "attr") and node.func.attr == "tid":
        lineno = adj.lineno + adj.fun_lineno
        line = adj.source_lines[adj.lineno]
        raise WarpCodegenError(
            "tid() may only be called from a Warp kernel, not a Warp function. "
            "Instead, obtain the indices from a @wp.kernel and pass them as "
            f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
        )
def emit_Call(adj, node):
    """Emit a function call: resolve the callee (builtin, vector/scalar
    constructor, struct initializer, or user @wp.func), evaluate positional
    and keyword arguments, and add the call to the current block.

    Fix: the tuple-keyword error message previously accessed `kw.name`,
    which does not exist on ast.keyword (the field is `arg`) and raised
    AttributeError instead of the intended diagnostic.
    """
    adj.check_tid_in_func_error(node)

    # try and lookup function in globals by
    # resolving path (e.g.: module.submodule.attr)
    func, path = adj.resolve_static_expression(node.func)
    templates = []

    if not isinstance(func, warp.context.Function):
        if len(path) == 0:
            raise WarpCodegenError(f"Unknown function or operator: '{node.func.func.id}'")

        attr = path[-1]
        caller = func
        func = None

        # try and lookup function name in builtins (e.g.: using `dot` directly without wp prefix)
        if attr in warp.context.builtin_functions:
            func = warp.context.builtin_functions[attr]

        # vector class type e.g.: wp.vec3f constructor
        if func is None and hasattr(caller, "_wp_generic_type_str_"):
            templates = caller._wp_type_params_
            func = warp.context.builtin_functions.get(caller._wp_constructor_)

        # scalar class type e.g.: wp.int8 constructor
        if func is None and hasattr(caller, "__name__") and caller.__name__ in warp.context.builtin_functions:
            func = warp.context.builtin_functions.get(caller.__name__)

        # struct constructor
        if func is None and isinstance(caller, Struct):
            adj.builder.build_struct_recursive(caller)
            func = caller.initializer()

        if func is None:
            raise WarpCodegenError(
                f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
            )

    args = []

    # eval all positional arguments
    for arg in node.args:
        var = adj.eval(arg)
        args.append(var)

    # eval all keyword args
    def kwval(kw):
        if isinstance(kw.value, ast.Num):
            return kw.value.n
        elif isinstance(kw.value, ast.Tuple):
            arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
            if not all(arg_is_numeric):
                # ast.keyword stores the keyword's name in `arg`
                raise WarpCodegenError(
                    f"All elements of the tuple keyword argument '{kw.arg}' must be numeric constants, got '{arg_values}'"
                )
            return arg_values
        else:
            return adj.resolve_static_expression(kw.value)[0]

    kwds = {kw.arg: kwval(kw) for kw in node.keywords}

    # get expected return count, e.g.: for multi-assignment
    min_outputs = None
    if hasattr(node, "expects"):
        min_outputs = node.expects

    # add var with value type from the function
    out = adj.add_call(func=func, args=args, kwds=kwds, templates=templates, min_outputs=min_outputs)
    return out
def emit_Index(adj, node):
    """Handle legacy ast.Index wrapper nodes (present on Python 3.7 array
    slices; absent from 3.8 onward)."""
    # propagate adjoint marking down to the wrapped value
    if hasattr(node, "is_adjoint"):
        node.value.is_adjoint = True
    return adj.eval(node.value)
def emit_Subscript(adj, node):
    """Emit an indexing expression: wp.adjoint[var] aliases, array
    loads/views, and vector/matrix element extraction."""
    if hasattr(node.value, "attr") and node.value.attr == "adjoint":
        # handle adjoint of a variable, i.e. wp.adjoint[var]
        node.slice.is_adjoint = True
        var = adj.eval(node.slice)
        var_name = var.label
        # alias the adjoint storage of the original variable via an adj_ prefixed Var
        var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
        return var

    target = adj.eval(node.value)

    indices = []

    if isinstance(node.slice, ast.Tuple):
        # handles the x[i,j] case (Python 3.8.x upward)
        for arg in node.slice.elts:
            var = adj.eval(arg)
            indices.append(var)
    elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
        # handles the x[i,j] case (Python 3.7.x)
        for arg in node.slice.value.elts:
            var = adj.eval(arg)
            indices.append(var)
    else:
        # simple expression, e.g.: x[i]
        var = adj.eval(node.slice)
        indices.append(var)

    target_type = strip_reference(target.type)
    if is_array(target_type):
        if len(indices) == target_type.ndim:
            # handles array loads (where each dimension has an index specified)
            out = adj.add_builtin_call("address", [target, *indices])
        else:
            # handles array views (fewer indices than dimensions)
            out = adj.add_builtin_call("view", [target, *indices])
    else:
        # handles non-array type indexing, e.g: vec3, mat33, etc
        out = adj.add_builtin_call("extract", [target, *indices])

    return out
def emit_Assign(adj, node):
    """Emit an assignment statement.

    Supports four target shapes: tuple unpacking of multi-output calls,
    subscript stores (array/vector/matrix element writes and wp.adjoint[]
    writes), simple name binding, and attribute stores on structs/vectors.
    """
    if len(node.targets) != 1:
        raise WarpCodegenError("Assigning the same value to multiple variables is not supported")

    lhs = node.targets[0]

    # handle the case where we are assigning multiple output variables
    if isinstance(lhs, ast.Tuple):
        # record the expected number of outputs on the node
        # we do this so we can decide which function to
        # call based on the number of expected outputs
        if isinstance(node.value, ast.Call):
            node.value.expects = len(lhs.elts)

        # evaluate values
        if isinstance(node.value, ast.Tuple):
            out = [adj.eval(v) for v in node.value.elts]
        else:
            out = adj.eval(node.value)

        names = []
        for v in lhs.elts:
            if isinstance(v, ast.Name):
                names.append(v.id)
            else:
                raise WarpCodegenError(
                    "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                )

        if len(names) != len(out):
            raise WarpCodegenError(
                f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
            )

        for name, rhs in zip(names, out):
            if name in adj.symbols:
                # re-binding an existing symbol must preserve its type
                if not types_equal(rhs.type, adj.symbols[name].type):
                    raise WarpCodegenTypeError(
                        f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                    )
            adj.symbols[name] = rhs

    # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
    elif isinstance(lhs, ast.Subscript):
        if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
            # handle adjoint of a variable, i.e. wp.adjoint[var]
            lhs.slice.is_adjoint = True
            src_var = adj.eval(lhs.slice)
            # write directly into the adjoint storage of the source variable
            var = Var(f"adj_{src_var.label}", type=src_var.type, constant=None, prefix=False)
            value = adj.eval(node.value)
            adj.add_forward(f"{var.emit()} = {value.emit()};")
            return

        target = adj.eval(lhs.value)
        value = adj.eval(node.value)

        slice = lhs.slice
        indices = []

        if isinstance(slice, ast.Tuple):
            # handles the x[i, j] case (Python 3.8.x upward)
            for arg in slice.elts:
                var = adj.eval(arg)
                indices.append(var)
        elif isinstance(slice, ast.Index) and isinstance(slice.value, ast.Tuple):
            # handles the x[i, j] case (Python 3.7.x)
            for arg in slice.value.elts:
                var = adj.eval(arg)
                indices.append(var)
        else:
            # simple expression, e.g.: x[i]
            var = adj.eval(slice)
            indices.append(var)

        target_type = strip_reference(target.type)
        if is_array(target_type):
            adj.add_builtin_call("array_store", [target, *indices, value])
        elif type_is_vector(target_type) or type_is_matrix(target_type):
            if is_reference(target.type):
                attr = adj.add_builtin_call("indexref", [target, *indices])
            else:
                attr = adj.add_builtin_call("index", [target, *indices])
            adj.add_builtin_call("store", [attr, value])

            if warp.config.verbose and not adj.custom_reverse_mode:
                lineno = adj.lineno + adj.fun_lineno
                line = adj.source_lines[adj.lineno]
                node_source = adj.get_node_source(lhs.value)
                print(
                    f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                )
        else:
            raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")

    elif isinstance(lhs, ast.Name):
        # symbol name
        name = lhs.id

        # evaluate rhs
        rhs = adj.eval(node.value)

        # check type matches if symbol already defined
        if name in adj.symbols:
            if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
                raise WarpCodegenTypeError(
                    f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
                )

        # handle simple assignment case (a = b), where we generate a value copy rather than reference
        if isinstance(node.value, ast.Name) or is_reference(rhs.type):
            out = adj.add_builtin_call("copy", [rhs])
        else:
            out = rhs

        # update symbol map (assumes lhs is a Name node)
        adj.symbols[name] = out

    elif isinstance(lhs, ast.Attribute):
        rhs = adj.eval(node.value)
        aggregate = adj.eval(lhs.value)
        aggregate_type = strip_reference(aggregate.type)

        # assigning to a vector component
        if type_is_vector(aggregate_type):
            index = adj.vector_component_index(lhs.attr, aggregate_type)

            if is_reference(aggregate.type):
                attr = adj.add_builtin_call("indexref", [aggregate, index])
            else:
                attr = adj.add_builtin_call("index", [aggregate, index])
            adj.add_builtin_call("store", [attr, rhs])
        else:
            # struct member store
            attr = adj.emit_Attribute(lhs)
            if is_reference(attr.type):
                adj.add_builtin_call("store", [attr, rhs])
            else:
                adj.add_builtin_call("assign", [attr, rhs])

            if warp.config.verbose and not adj.custom_reverse_mode:
                lineno = adj.lineno + adj.fun_lineno
                line = adj.source_lines[adj.lineno]
                msg = f'Warning: detected mutated struct {attr.label} during function "{adj.fun_name}" at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n'
                print(msg)

    else:
        raise WarpCodegenError("Error, unsupported assignment statement.")
def emit_Return(adj, node):
    """Emit a return statement, checking that all return sites in the
    function agree on the returned types.

    NOTE(review): a bare `return` following an earlier value-returning
    return leaves `var` as None, and the ctype comparison below would then
    fail while iterating it — presumably unreachable in valid kernels;
    confirm.
    """
    if node.value is None:
        var = None
    elif isinstance(node.value, ast.Tuple):
        var = tuple(adj.eval(arg) for arg in node.value.elts)
    else:
        var = (adj.eval(node.value),)

    if adj.return_var is not None:
        # compare against the types recorded by a previous return statement
        old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
        new_ctypes = tuple(v.ctype(value_type=True) for v in var)
        if old_ctypes != new_ctypes:
            raise WarpCodegenTypeError(
                f"Error, function returned different types, previous: [{', '.join(old_ctypes)}], new [{', '.join(new_ctypes)}]"
            )

    if var is not None:
        adj.return_var = tuple()
        for ret in var:
            if is_reference(ret.type):
                # materialize references into values before returning
                ret = adj.add_builtin_call("copy", [ret])
            adj.return_var += (ret,)

    adj.add_return(adj.return_var)
def emit_AugAssign(adj, node):
    """Lower an augmented assignment (x += y) into x = x <op> y and
    evaluate the rewritten statement."""
    binop = ast.BinOp(node.target, node.op, node.value)
    adj.eval(ast.Assign(targets=[node.target], value=binop))
def emit_Tuple(adj, node):
    """Evaluate each element of a tuple expression, e.g. the LHS of
    i, j, k = 1, 2, 3."""
    for element in node.elts:
        adj.eval(element)
def emit_Pass(adj, node):
    """A pass statement emits nothing."""
    return None
# dispatch table used by eval(): maps each supported AST node type to its emitter
node_visitors = {
    ast.FunctionDef: emit_FunctionDef,
    ast.If: emit_If,
    ast.Compare: emit_Compare,
    ast.BoolOp: emit_BoolOp,
    ast.Name: emit_Name,
    ast.Attribute: emit_Attribute,
    ast.Str: emit_String,  # Deprecated in 3.8; use Constant
    ast.Num: emit_Num,  # Deprecated in 3.8; use Constant
    ast.NameConstant: emit_NameConstant,  # Deprecated in 3.8; use Constant
    ast.Constant: emit_Constant,
    ast.BinOp: emit_BinOp,
    ast.UnaryOp: emit_UnaryOp,
    ast.While: emit_While,
    ast.For: emit_For,
    ast.Break: emit_Break,
    ast.Continue: emit_Continue,
    ast.Expr: emit_Expr,
    ast.Call: emit_Call,
    ast.Index: emit_Index,  # Deprecated in 3.8; Use the index value directly instead.
    ast.Subscript: emit_Subscript,
    ast.Assign: emit_Assign,
    ast.Return: emit_Return,
    ast.AugAssign: emit_AugAssign,
    ast.Tuple: emit_Tuple,
    ast.Pass: emit_Pass,
    ast.Ellipsis: emit_Ellipsis,
}
def eval(adj, node):
    """Dispatch an AST node to its emitter, tracking the current source line.

    Raises WarpCodegenKeyError (a KeyError subclass, so any existing
    KeyError handlers still apply) with a readable message for language
    constructs that have no emitter, instead of a bare KeyError.
    """
    if hasattr(node, "lineno"):
        adj.set_lineno(node.lineno - 1)

    emit_node = adj.node_visitors.get(type(node))
    if emit_node is None:
        raise WarpCodegenKeyError(f"Unsupported construct: {type(node).__name__}")

    return emit_node(adj, node)
# helper to evaluate expressions of the form
# obj1.obj2.obj3.attr in the function's global scope
def resolve_path(adj, path):
    """Resolve a dotted name path against, in order: local symbols (which
    shadow and abort the lookup), builtins, the function's globals merged
    with its closure variables, and finally the `warp` module itself.

    Returns the resolved object, or None when the path is empty, its root
    is shadowed by a local symbol, or nothing matches.
    """
    if len(path) == 0:
        return None

    # if root is overshadowed by local symbols, bail out
    if path[0] in adj.symbols:
        return None

    if path[0] in __builtins__:
        return __builtins__[path[0]]

    # Look up the closure info and append it to adj.func.__globals__
    # in case you want to define a kernel inside a function and refer
    # to variables you've declared inside that function.
    # (A previous extract_contents lambda here returned its argument on
    # every branch — a no-op — and has been removed.)
    capturedvars = dict(
        zip(
            adj.func.__code__.co_freevars,
            [c.cell_contents for c in (adj.func.__closure__ or [])],
        )
    )
    vars_dict = {**adj.func.__globals__, **capturedvars}

    if path[0] in vars_dict:
        func = vars_dict[path[0]]
    # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
    else:
        func = getattr(warp, path[0], None)

    if func:
        # walk the remaining attributes of the path
        for i in range(1, len(path)):
            if hasattr(func, path[i]):
                func = getattr(func, path[i])

    return func
# Evaluates a static expression that does not depend on runtime values
# if eval_types is True, try resolving the path using evaluated type information as well
def resolve_static_expression(adj, root_node, eval_types=True):
    """Resolve a dotted expression (e.g. module.submodule.attr) statically.

    Returns (object, path) where path is the list of attribute names
    traversed, or (None, path) when the expression cannot be resolved.
    """
    attributes = []

    node = root_node
    while isinstance(node, ast.Attribute):
        attributes.append(node.attr)
        node = node.value

    if eval_types and isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
        # support for operators returning modules
        # i.e. operator_name(*operator_args).x.y.z
        operator_args = node.args
        operator_name = node.func.id

        if operator_name == "type":
            if len(operator_args) != 1:
                raise WarpCodegenError(f"type() operator expects exactly one argument, got {len(operator_args)}")

            # type() operator
            var = adj.eval(operator_args[0])

            if isinstance(var, Var):
                var_type = strip_reference(var.type)

                # Allow accessing type attributes, for instance array.dtype
                while attributes:
                    attr_name = attributes.pop()
                    var_type, prev_type = adj.resolve_type_attribute(var_type, attr_name), var_type

                    if var_type is None:
                        raise WarpCodegenAttributeError(
                            f"{attr_name} is not an attribute of {type_repr(prev_type)}"
                        )

                return var_type, [type_repr(var_type)]
            else:
                raise WarpCodegenError(f"Cannot deduce the type of {var}")

    # reverse list since ast presents it backward order
    path = [*reversed(attributes)]
    if isinstance(node, ast.Name):
        path.insert(0, node.id)

    # Try resolving path from captured context
    captured_obj = adj.resolve_path(path)
    if captured_obj is not None:
        return captured_obj, path

    # Still nothing found, maybe this is a predefined type attribute like `dtype`
    if eval_types:
        try:
            val = adj.eval(root_node)
            if val:
                # return (value, path) for consistency with the other return
                # sites; previously this returned [val, type_repr(val)] whose
                # second element was a bare string, so callers indexing
                # path[-1] received a single character
                return val, [type_repr(val)]
        except Exception:
            pass

    return None, path
# annotate generated code with the original source code line
def set_lineno(adj, lineno):
    """Emit `//` source-line markers into the forward and reverse streams
    whenever the current line changes."""
    if adj.lineno == lineno:
        return
    displayed = lineno + adj.fun_lineno
    source = adj.source_lines[lineno].strip().ljust(80 - len(adj.indentation), " ")
    adj.add_forward(f"// {source} <L {displayed}>")
    adj.add_reverse(f"// adj: {source} <L {displayed}>")
    adj.lineno = lineno
def get_node_source(adj, node):
    """Return the original Python source text corresponding to `node`."""
    return ast.get_source_segment(adj.source, node)
# ----------------
# code generation

# C++ preamble for CPU modules: tid() builtins index via the per-thread
# wp::s_threadIdx set by the module entry-point loop
cpu_module_header = """
#define WP_NO_CRT
#include "builtin.h"
// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
#define int(x) cast_int(x)
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
#define builtin_tid1d() wp::tid(wp::s_threadIdx)
#define builtin_tid2d(x, y) wp::tid(x, y, wp::s_threadIdx, dim)
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, wp::s_threadIdx, dim)
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, wp::s_threadIdx, dim)
"""

# CUDA preamble: identical except tid() builtins index via the kernel-local _idx
cuda_module_header = """
#define WP_NO_CRT
#include "builtin.h"
// avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
#define float(x) cast_float(x)
#define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
#define int(x) cast_int(x)
#define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
#define builtin_tid1d() wp::tid(_idx)
#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
"""

# C++ struct definition plus its adjoint helpers; filled in by codegen_struct()
struct_template = """
struct {name}
{{
{struct_body}
CUDA_CALLABLE {name}({forward_args})
{forward_initializers}
{{
}}
CUDA_CALLABLE {name}& operator += (const {name}& rhs)
{{{prefix_add_body}
return *this;}}
}};
static CUDA_CALLABLE void adj_{name}({reverse_args})
{{
{reverse_body}}}
CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
{{
{atomic_add_body}}}
"""

# per-function forward/reverse bodies (CPU and CUDA variants)
cpu_forward_function_template = """
// (unknown):{lineno}
static {return_type} {name}(
{forward_args})
{{
{forward_body}}}
"""

cpu_reverse_function_template = """
// (unknown):{lineno}
static void adj_{name}(
{reverse_args})
{{
{reverse_body}}}
"""

cuda_forward_function_template = """
// (unknown):{lineno}
static CUDA_CALLABLE {return_type} {name}(
{forward_args})
{{
{forward_body}}}
"""

cuda_reverse_function_template = """
// (unknown):{lineno}
static CUDA_CALLABLE void adj_{name}(
{reverse_args})
{{
{reverse_body}}}
"""

# CUDA kernel entry points: a grid-stride style loop over dim.size
cuda_kernel_template = """
extern "C" __global__ void {name}_cuda_kernel_forward(
{forward_args})
{{
for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
_idx < dim.size;
_idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x)) {{
{forward_body}}}}}
extern "C" __global__ void {name}_cuda_kernel_backward(
{reverse_args})
{{
for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
_idx < dim.size;
_idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x)) {{
{reverse_body}}}}}
"""

# CPU kernel bodies, invoked per thread index by the module entry points
cpu_kernel_template = """
void {name}_cpu_kernel_forward(
{forward_args})
{{
{forward_body}}}
void {name}_cpu_kernel_backward(
{reverse_args})
{{
{reverse_body}}}
"""

# CPU module entry points: serial loop setting wp::s_threadIdx per iteration
cpu_module_template = """
extern "C" {{
// Python CPU entry points
WP_API void {name}_cpu_forward(
{forward_args})
{{
for (size_t i=0; i < dim.size; ++i)
{{
wp::s_threadIdx = i;
{name}_cpu_kernel_forward(
{forward_params});
}}
}}
WP_API void {name}_cpu_backward(
{reverse_args})
{{
for (size_t i=0; i < dim.size; ++i)
{{
wp::s_threadIdx = i;
{name}_cpu_kernel_backward(
{reverse_params});
}}
}}
}} // extern C
"""

# extern "C" declarations exposed to the Python bindings
cuda_module_header_template = """
extern "C" {{
// Python CUDA entry points
WP_API void {name}_cuda_forward(
void* stream,
{forward_args});
WP_API void {name}_cuda_backward(
void* stream,
{reverse_args});
}} // extern C
"""

cpu_module_header_template = """
extern "C" {{
// Python CPU entry points
WP_API void {name}_cpu_forward(
{forward_args});
WP_API void {name}_cpu_backward(
{reverse_args});
}} // extern C
"""
# converts a constant Python value to equivalent C-repr
def constant_str(value):
    """Render a constant Python value as C++ source text."""
    value_type = type(value)

    if value_type == bool or value_type == builtins.bool:
        return "true" if value else "false"

    if value_type == str:
        # ensure constant strings are correctly escaped
        return '"' + str(value.encode("unicode-escape").decode()) + '"'

    if isinstance(value, ctypes.Array):
        if value_type._wp_scalar_type_ == float16:
            # special case for float16, which is stored as uint16 in the ctypes.Array
            from warp.context import runtime

            scalar_value = runtime.core.half_bits_to_float
        else:

            def scalar_value(x):
                return x

        # list of scalar initializer values
        initlist = [str(scalar_value(ctypes.Array.__getitem__(value, i))) for i in range(value._length_)]

        # construct value from initializer array, e.g. wp::initializer_array<4,wp::float32>{1.0, 2.0, 3.0, 4.0}
        dtypestr = f"wp::initializer_array<{value._length_},wp::{value._wp_scalar_type_.__name__}>"
        return f"{dtypestr}{{{', '.join(initlist)}}}"

    if value_type in warp.types.scalar_types:
        # make sure we emit the value of objects, e.g. uint32
        return str(value.value)

    # otherwise just convert constant to string
    return str(value)
def indent(args, stops=1):
    """Join `args` with a comma-newline separator followed by `stops`
    levels of spacing; used to lay out multi-line C argument lists.

    NOTE(review): the separator is only placed *between* items, not
    prepended — see the commented-out variant below for an earlier form.
    """
    sep = ",\n"
    for i in range(stops):
        sep += " "

    # return sep + args.replace(", ", "," + sep)
    return sep.join(args)
# generates a C function name based on the python function name
def make_full_qualified_name(func):
    """Derive a C-safe identifier from a function object (or a name string):
    dots become double underscores, other illegal characters are stripped."""
    qualname = func if isinstance(func, str) else func.__qualname__
    return re.sub("[^0-9a-zA-Z_]+", "", qualname.replace(".", "__"))
def codegen_struct(struct, device="cpu", indent_size=4):
    """Emit the C++ definition (members, constructor, adjoint helpers) for a Warp struct."""
    name = make_full_qualified_name(struct.cls)
    pad = " " * indent_size

    # member declarations; empty structs get a dummy member to avoid any
    # compiler-specific alignment issues
    if struct.vars:
        member_lines = [f"{var.ctype()} {label};\n" for label, var in struct.vars.items()]
    else:
        member_lines = ["char _dummy_;\n"]

    forward_args = []
    reverse_args = []
    forward_initializers = []
    reverse_body = []
    atomic_add_body = []
    prefix_add_body = []

    # constructor arguments, initializer list, and atomic adjoint accumulation
    for label, var in struct.vars.items():
        ctype = var.ctype()
        forward_args.append(f"{ctype} const& {label} = {{}}")
        reverse_args.append(f"{ctype} const&")

        # builtin types live in the wp:: namespace
        ns = "wp::" if ctype.startswith("wp::") or ctype == "bool" else ""
        atomic_add_body.append(f"{pad}{ns}adj_atomic_add(&p->{label}, t.{label});\n")

        lead = f"{pad}," if forward_initializers else ":"
        forward_initializers.append(f"{pad}{lead} {label}{{{label}}}\n")

    # prefix-add operator accumulates all non-array members
    for label, var in struct.vars.items():
        if not is_array(var.type):
            prefix_add_body.append(f"{pad}{label} += rhs.{label};\n")

    # adjoint outputs: array gradients are assigned, everything else accumulated
    for label, var in struct.vars.items():
        reverse_args.append(var.ctype() + " & adj_" + label)
        if is_array(var.type):
            reverse_body.append(f"{pad}adj_{label} = adj_ret.{label};\n")
        else:
            reverse_body.append(f"{pad}adj_{label} += adj_ret.{label};\n")

    reverse_args.append(name + " & adj_ret")

    return struct_template.format(
        name=name,
        struct_body="".join(pad + line for line in member_lines),
        forward_args=indent(forward_args),
        forward_initializers="".join(forward_initializers),
        reverse_args=indent(reverse_args),
        reverse_body="".join(reverse_body),
        prefix_add_body="".join(prefix_add_body),
        atomic_add_body="".join(atomic_add_body),
    )
def codegen_func_forward_body(adj, device="cpu", indent=4):
    """Return the forward-pass statements of *adj*, one statement per line, each indented."""
    pad = " " * indent
    return "".join(pad + stmt + "\n" for stmt in adj.blocks[0].body_forward)
def codegen_func_forward(adj, func_type="kernel", device="cpu"):
    """Emit primal variable declarations followed by the forward-pass body."""
    lines = []

    # primal vars
    lines.append("    //---------\n")
    lines.append("    // primal vars\n")
    for var in adj.variables:
        if var.constant is None:
            lines.append(f"    {var.ctype()} {var.emit()};\n")
        else:
            lines.append(f"    const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n")

    # forward pass; CUDA kernel bodies are nested one level deeper than functions
    lines.append("    //---------\n")
    lines.append("    // forward\n")

    if device == "cpu":
        lines.append(codegen_func_forward_body(adj, device=device, indent=4))
    elif device == "cuda":
        depth = 8 if func_type == "kernel" else 4
        lines.append(codegen_func_forward_body(adj, device=device, indent=depth))

    return "".join(lines)
def codegen_func_reverse_body(adj, device="cpu", indent=4, func_type="kernel"):
    """Emit the replayed forward statements followed by the adjoint statements in reverse order."""
    lines = ["//---------\n", "// forward\n"]

    # replay the forward pass to rebuild intermediate values
    lines.extend(stmt + "\n" for stmt in adj.blocks[0].body_replay)

    # adjoint statements are emitted in reverse order
    lines.append("//---------\n")
    lines.append("// reverse\n")
    lines.extend(stmt + "\n" for stmt in reversed(adj.blocks[0].body_reverse))

    # In grid-stride kernels the reverse body is in a for loop
    if device == "cuda" and func_type == "kernel":
        lines.append("continue;\n")
    else:
        lines.append("return;\n")

    pad = " " * indent
    return "".join(pad + line for line in lines)
def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
    """Emit primal and zero-initialized dual variable declarations, then the reverse-pass body.

    Raises:
        ValueError: if *device* is not "cpu" or "cuda".
    """
    out = []

    # primal vars
    out.append("    //---------\n")
    out.append("    // primal vars\n")
    for var in adj.variables:
        if var.constant is None:
            out.append(f"    {var.ctype()} {var.emit()};\n")
        else:
            out.append(f"    const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n")

    # dual vars, zero-initialized
    out.append("    //---------\n")
    out.append("    // dual vars\n")
    for var in adj.variables:
        out.append(f"    {var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n")

    # CUDA kernel bodies are nested one level deeper than functions
    if device == "cpu":
        depth = 4
    elif device == "cuda":
        depth = 8 if func_type == "kernel" else 4
    else:
        raise ValueError(f"Device {device} not supported for codegen")

    out.append(codegen_func_reverse_body(adj, device=device, indent=depth, func_type=func_type))
    return "".join(out)
def codegen_func(adj, c_func_name: str, device="cpu", options=None):
    """Generate the C++ forward and reverse functions for a user function.

    Args:
        adj: the Adjoint object describing the function being generated.
        c_func_name: mangled C function name to emit.
        device: "cpu" or "cuda".
        options: optional module options dict (e.g. "enable_backward").

    Returns:
        The generated source string (forward and/or reverse function).

    Raises:
        ValueError: if *device* is not supported.
    """
    # avoid the mutable-default-argument pitfall: normalize None to an empty dict
    if options is None:
        options = {}

    # forward header
    if adj.return_var is not None and len(adj.return_var) == 1:
        return_type = adj.return_var[0].ctype()
    else:
        return_type = "void"

    has_multiple_outputs = adj.return_var is not None and len(adj.return_var) != 1

    forward_args = []
    reverse_args = []

    # forward args
    for i, arg in enumerate(adj.args):
        s = f"{arg.ctype()} {arg.emit()}"
        forward_args.append(s)
        # in custom reverse mode only the user-declared inputs are reverse parameters
        if not adj.custom_reverse_mode or i < adj.custom_reverse_num_input_args:
            reverse_args.append(s)
    if has_multiple_outputs:
        # multiple return values become output reference parameters
        for i, arg in enumerate(adj.return_var):
            forward_args.append(arg.ctype() + " & ret_" + str(i))
            reverse_args.append(arg.ctype() + " & ret_" + str(i))

    # reverse args
    for i, arg in enumerate(adj.args):
        if adj.custom_reverse_mode and i >= adj.custom_reverse_num_input_args:
            break
        # indexed array gradients are regular arrays
        if isinstance(arg.type, indexedarray):
            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
            reverse_args.append(_arg.ctype() + " & adj_" + arg.label)
        else:
            reverse_args.append(arg.ctype() + " & adj_" + arg.label)
    if has_multiple_outputs:
        for i, arg in enumerate(adj.return_var):
            reverse_args.append(arg.ctype() + " & adj_ret_" + str(i))
    elif return_type != "void":
        reverse_args.append(return_type + " & adj_ret")
    # custom output reverse args (user-declared)
    if adj.custom_reverse_mode:
        for arg in adj.args[adj.custom_reverse_num_input_args :]:
            reverse_args.append(f"{arg.ctype()} & {arg.emit()}")

    if device == "cpu":
        forward_template = cpu_forward_function_template
        reverse_template = cpu_reverse_function_template
    elif device == "cuda":
        forward_template = cuda_forward_function_template
        reverse_template = cuda_reverse_function_template
    else:
        raise ValueError(f"Device {device} is not supported")

    # codegen body
    forward_body = codegen_func_forward(adj, func_type="function", device=device)

    s = ""
    if not adj.skip_forward_codegen:
        s += forward_template.format(
            name=c_func_name,
            return_type=return_type,
            forward_args=indent(forward_args),
            forward_body=forward_body,
            filename=adj.filename,
            lineno=adj.fun_lineno,
        )

    if not adj.skip_reverse_codegen:
        if adj.custom_reverse_mode:
            # user supplied their own adjoint: replay the (user-defined) forward body
            reverse_body = "\t// user-defined adjoint code\n" + forward_body
        else:
            if options.get("enable_backward", True):
                reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
            else:
                reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
        s += reverse_template.format(
            name=c_func_name,
            return_type=return_type,
            reverse_args=indent(reverse_args),
            forward_body=forward_body,
            reverse_body=reverse_body,
            filename=adj.filename,
            lineno=adj.fun_lineno,
        )

    return s
def codegen_snippet(adj, name, snippet, adj_snippet):
    """Wrap a user-provided native CUDA code snippet (and optional adjoint snippet) in the function templates."""
    forward_args = []
    reverse_args = []

    # forward args: strip the "var_" prefix so the snippet sees the plain argument names
    for arg in adj.args:
        decl = f"{arg.ctype()} {arg.emit().replace('var_', '')}"
        forward_args.append(decl)
        reverse_args.append(decl)

    # adjoint args; indexed array gradients are regular arrays
    for arg in adj.args:
        if isinstance(arg.type, indexedarray):
            grad_var = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
            reverse_args.append(grad_var.ctype() + " & adj_" + arg.label)
        else:
            reverse_args.append(arg.ctype() + " & adj_" + arg.label)

    s = cuda_forward_function_template.format(
        name=name,
        return_type="void",
        forward_args=indent(forward_args),
        forward_body=snippet,
        filename=adj.filename,
        lineno=adj.fun_lineno,
    )
    s += cuda_reverse_function_template.format(
        name=name,
        return_type="void",
        reverse_args=indent(reverse_args),
        forward_body=snippet,
        # empty reverse body when no adjoint snippet was supplied
        reverse_body=adj_snippet if adj_snippet else "",
        filename=adj.filename,
        lineno=adj.fun_lineno,
    )
    return s
def codegen_kernel(kernel, device, options):
    """Generate the forward and reverse kernel entry points for *kernel*.

    Args:
        kernel: the Kernel object to generate code for.
        device: "cpu" or "cuda".
        options: module options; kernel-level options take precedence.

    Returns:
        The generated kernel source string.

    Raises:
        ValueError: if *device* is not supported.
    """
    # Update the module's options with the ones defined on the kernel, if any.
    options = dict(options)
    options.update(kernel.options)

    adj = kernel.adj

    forward_args = ["wp::launch_bounds_t dim"]
    reverse_args = ["wp::launch_bounds_t dim"]

    # forward args
    for arg in adj.args:
        forward_args.append(arg.ctype() + " var_" + arg.label)
        reverse_args.append(arg.ctype() + " var_" + arg.label)

    # reverse args
    for arg in adj.args:
        # indexed array gradients are regular arrays
        if isinstance(arg.type, indexedarray):
            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
            reverse_args.append(_arg.ctype() + " adj_" + arg.label)
        else:
            reverse_args.append(arg.ctype() + " adj_" + arg.label)

    # codegen body
    forward_body = codegen_func_forward(adj, func_type="kernel", device=device)

    # use .get with a True default for consistency with codegen_func(), so a
    # missing "enable_backward" option no longer raises KeyError
    if options.get("enable_backward", True):
        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
    else:
        reverse_body = ""

    if device == "cpu":
        template = cpu_kernel_template
    elif device == "cuda":
        template = cuda_kernel_template
    else:
        raise ValueError(f"Device {device} is not supported")

    s = template.format(
        name=kernel.get_mangled_name(),
        forward_args=indent(forward_args),
        reverse_args=indent(reverse_args),
        forward_body=forward_body,
        reverse_body=reverse_body,
    )

    return s
def codegen_module(kernel, device="cpu"):
    """Generate the C-linkage CPU module entry points called from Python; returns "" for non-CPU devices."""
    if device != "cpu":
        return ""

    adj = kernel.adj

    # build forward signature
    forward_args = ["wp::launch_bounds_t dim"]
    forward_params = ["dim"]

    for arg in adj.args:
        if hasattr(arg.type, "_wp_generic_type_str_"):
            # vectors and matrices are passed from Python by pointer
            forward_args.append(f"const {arg.ctype()}* var_{arg.label}")
            forward_params.append(f"*var_{arg.label}")
        else:
            forward_args.append(f"{arg.ctype()} var_{arg.label}")
            forward_params.append(f"var_{arg.label}")

    # build reverse signature, starting from a copy of the forward one
    reverse_args = list(forward_args)
    reverse_params = list(forward_params)

    for arg in adj.args:
        if isinstance(arg.type, indexedarray):
            # indexed array gradients are regular arrays
            grad_var = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
            reverse_args.append(f"const {grad_var.ctype()} adj_{arg.label}")
            reverse_params.append(f"adj_{arg.label}")
        elif hasattr(arg.type, "_wp_generic_type_str_"):
            # vectors and matrices are passed from Python by pointer
            reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
            reverse_params.append(f"*adj_{arg.label}")
        else:
            reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
            reverse_params.append(f"adj_{arg.label}")

    return cpu_module_template.format(
        name=kernel.get_mangled_name(),
        forward_args=indent(forward_args),
        reverse_args=indent(reverse_args),
        forward_params=indent(forward_params, 3),
        reverse_params=indent(reverse_params, 3),
    )