| import inspect |
| import re |
| import textwrap |
| import types |
| import triton |
|
|
|
|
def cacheable(f):
    """
    Decorator that builds a dynamically-defined kernel once and gives it ``f``'s identity.

    Intended usage::

        @cacheable
        def my_kernel(): return (expression dynamically defining a kernel)

    ``f`` is called exactly once, at decoration time. Its return value must expose
    a ``.fn`` attribute (e.g. the ``triton.JITFunction`` produced by
    ``define_kernel``). That inner function's ``__name__``, ``__module__`` and
    ``__qualname__`` are overwritten with ``f``'s. ``_fn_name`` is set to
    ``"<module>.<qualname>"`` of ``f``, so that the kernel interacts gracefully
    with triton's cache and preload.

    Returns the object produced by ``f()`` (not ``f`` itself).
    """
    kernel = f()
    inner = kernel.fn
    # Make the generated function look as if it had been written where ``f`` is.
    for attr in ("__name__", "__module__", "__qualname__"):
        setattr(inner, attr, getattr(f, attr))
    kernel._fn_name = f"{f.__module__}.{f.__qualname__}"
    return kernel
|
|
|
|
def define_kernel(src, module, attrs=None, **extra_globals):
    """
    Dynamically create a Triton function or kernel from a src string,
    linking any symbols in the kernel to objects specified by extra_globals.

    Args:
        src: Python source containing a function definition. It is dedented, and
            everything before the first ``"def "`` (e.g. an ``@triton.jit``
            decorator line) is discarded.
        module: Only ``module.__name__`` is used; it becomes the ``__module__``
            of the created function.
        attrs: Optional mapping of keyword arguments forwarded to
            ``triton.JITFunction``.
        **extra_globals: Names made visible as globals to the kernel. They are
            layered over this module's globals.

    Returns:
        A ``triton.JITFunction`` whose signature, name and docstring come from
        the function defined in ``src``.

    Raises:
        ValueError: if ``src`` contains no ``"def "``.

    Note: ``src`` is executed with ``exec``; only pass trusted, generated source.
    """
    src = textwrap.dedent(src)
    def_pos = src.find("def ")
    if def_pos < 0:
        # Previously this sliced ``src[-1:]`` and failed later with an obscure error.
        raise ValueError("Could not find a function definition ('def ') in src")
    src = src[def_pos:]
    function_name = src[4:].split("(")[0].strip()

    def _empty_fn():
        pass

    # Globals seen by the kernel: this module's globals overlaid with extra_globals.
    gdict = dict(_empty_fn.__globals__)
    gdict.update(extra_globals)
    f = types.FunctionType(_empty_fn.__code__, gdict)
    f.__module__ = module.__name__

    # Execute the source in a throwaway copy of the globals. The resulting
    # plain-Python function is used only for its signature and docstring, so
    # neither it nor any helper leaks into the kernel's own globals (``gdict``).
    namespace = dict(gdict)
    exec(src, namespace)
    py_fn = namespace[function_name]

    f.__signature__ = inspect.signature(py_fn)
    f.__name__ = function_name
    f.__doc__ = py_fn.__doc__

    f = triton.JITFunction(f, **(attrs if attrs is not None else {}))
    # Presumably makes triton use ``src`` as the kernel's source text instead of
    # the placeholder ``_empty_fn`` body (private triton API — confirm on upgrade).
    f._unsafe_update_src(src)
    return f
|
|
|
|
def _split_function_source(src):
    """
    Split dedented function source into its parameter strings and body lines.

    ``src`` may start with decorator lines. The header starts at the first line
    whose stripped text begins with ``def`` followed by whitespace. It ends at the
    first line whose code, ignoring a trailing ``#`` comment, ends with ``:``.

    Returns ``(args, body_lines)``:
        ``args`` are the raw comma-separated parameter strings, with annotations
        and defaults included and surrounding whitespace stripped.
        ``body_lines`` are all lines after the header, with original indentation.

    Raises ``ValueError`` if no header can be found or parsed. Parameters are split
    on every comma and comments on every ``#``, so defaults containing either
    character are not supported.
    """
    lines = src.split("\n")

    def code_part(line):
        # Strip a trailing comment so it cannot hide (or fake) the header's closing colon.
        return line.split("#", 1)[0].strip()

    def_idx = next(
        (i for i, line in enumerate(lines) if line.lstrip().startswith(("def ", "def\t"))),
        None,
    )
    if def_idx is None:
        raise ValueError("Could not find function header")

    header_end = def_idx
    while header_end < len(lines) and not code_part(lines[header_end]).endswith(":"):
        header_end += 1
    if header_end == len(lines):
        raise ValueError("Could not find end of function header")

    header_src = " ".join(
        part for part in map(code_part, lines[def_idx:header_end + 1]) if part
    )
    # Greedy: capture up to the last ")" that is followed by an optional return
    # annotation and the closing ":".
    match = re.search(r"\((.*)\)\s*(?:->[^:]*)?:", header_src)
    if not match:
        raise ValueError("Could not parse function header")
    args = [arg.strip() for arg in match.group(1).split(",") if arg.strip()]
    return args, lines[header_end + 1:]


def _body_indent(body_lines, default="    "):
    """
    Return the leading whitespace of the first code line in ``body_lines``.

    Blank and comment-only lines are skipped because they may sit at any
    indentation. Falls back to ``default`` when the body has no code line.
    """
    for line in body_lines:
        stripped = line.lstrip()
        if stripped and not stripped.startswith("#"):
            return line[: len(line) - len(stripped)]
    return default


def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
    """
    Create a new Triton kernel from ``fn`` with some parameters baked in.

    The source of ``fn.fn`` is rewritten as follows, then re-created via ``define_kernel``:

    * every parameter named in ``constants`` is dropped from the signature and
      re-bound at the top of the body as ``<name>: tl.constexpr = <value>``.
      Callable values are written by their ``__name__``. Other values are written
      via ``str()``, so they must render as valid Python source.
    * every parameter named in ``tuples`` (and not in ``constants``) is replaced in
      the signature by the names in ``tuples[param]``. A line
      ``<param> = (<names>,)`` then rebuilds it at the top of the body.

    ``triton.runtime.jit.JITFunction`` values in ``constants`` are exposed as
    globals under their ``__name__``, together with ``fn.get_capture_scope()``.
    Their ``repr`` is appended to the new kernel's ``repr``.

    Args:
        fn: The ``JITFunction`` to specialize.
        module: Forwarded to ``define_kernel`` (supplies the new ``__module__``).
        constants: Mapping of parameter name -> value to bake in.
        tuples: Mapping of parameter name -> list of replacement parameter names.
        name: Name of the new kernel; defaults to ``fn.__name__``.
        do_not_specialize: Forwarded to ``JITFunction`` when non-empty.

    Returns:
        The new ``JITFunction``.

    Raises:
        ValueError: if the function header of ``fn`` cannot be located or parsed.
    """
    assert isinstance(fn, triton.runtime.jit.JITFunction)
    if name is None:
        name = fn.__name__

    src = textwrap.dedent(inspect.getsource(fn.fn))
    args, body_lines = _split_function_source(src)

    non_specialized_args = []
    for arg in args:
        arg_key = arg.split(":")[0].split("=")[0].strip()
        if arg_key not in constants:
            non_specialized_args += tuples.get(arg_key, [arg])

    spec_fns = {
        v.__name__: v
        for v in constants.values()
        if isinstance(v, triton.runtime.jit.JITFunction)
    }
    extra_globals = spec_fns | fn.get_capture_scope()

    # Inserted statements must match the body's own indentation, or the generated
    # source raises IndentationError.
    indent = _body_indent(body_lines)
    new_signature = f"def {name}({', '.join(non_specialized_args)}):"
    constexpr_lines = [
        f"{indent}{key}: tl.constexpr = {value.__name__ if callable(value) else value}"
        for key, value in constants.items()
    ]
    tuple_lines = [
        f"{indent}{key} = ({','.join(value)}{',' if value else ''})"
        for key, value in tuples.items()
    ]
    new_src = "\n".join(
        ["@triton.jit", new_signature, *constexpr_lines, *tuple_lines, *body_lines]
    )

    # Inherit every JITFunction.__init__ option from ``fn``, skipping the first two
    # parameters (self and the wrapped function).
    init_params = list(
        inspect.signature(triton.runtime.jit.JITFunction.__init__).parameters.values()
    )[2:]
    attrs = {param.name: getattr(fn, param.name, param.default) for param in init_params}

    base_repr = attrs["repr"]

    def new_repr(specialization):
        # Suffix the base repr with each specialized sub-function's repr so that
        # differently specialized kernels get distinct names.
        ret = base_repr(specialization)
        for spec_fn in spec_fns.values():
            spec_repr = (spec_fn.repr(None) or "").strip("_")
            if spec_repr:
                ret += f"_{spec_repr}"
        return ret

    attrs["repr"] = new_repr

    # NOTE(review): when do_not_specialize is empty, the value inherited from ``fn``
    # (via getattr above) refers to fn's original parameter list. If triton stores it
    # as positional indices, they may not line up with the specialized signature —
    # confirm against the triton version in use.
    if do_not_specialize:
        attrs["do_not_specialize"] = do_not_specialize
    return define_kernel(new_src, module, attrs, **extra_globals)
|
|