| """Convenience functions built on top of `make_vjp`.""" |
|
|
| from collections import OrderedDict |
|
|
try:
    from inspect import getfullargspec as _getargspec  # Python 3
except ImportError:
    from inspect import getargspec as _getargspec  # Python 2
| import warnings |
|
|
| import autograd.numpy as np |
|
|
| from .builtins import tuple as atuple |
| from .core import make_jvp as _make_jvp |
| from .core import make_vjp as _make_vjp |
| from .extend import defvjp_argnum, primitive, vspace |
| from .wrap_util import unary_to_nary |
|
|
| make_vjp = unary_to_nary(_make_vjp) |
| make_jvp = unary_to_nary(_make_jvp) |
|
|
|
|
| @unary_to_nary |
| def grad(fun, x): |
| """ |
| Returns a function which computes the gradient of `fun` with respect to |
| positional argument number `argnum`. The returned function takes the same |
| arguments as `fun`, but returns the gradient instead. The function `fun` |
| should be scalar-valued. The gradient has the same type as the argument.""" |
| vjp, ans = _make_vjp(fun, x) |
    if vspace(ans).size != 1:
| raise TypeError( |
| "Grad only applies to real scalar-output functions. " |
| "Try jacobian, elementwise_grad or holomorphic_grad." |
| ) |
| return vjp(vspace(ans).ones()) |
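
# A minimal usage sketch for `grad` (illustrative; `tanh` is just an example
# scalar-to-scalar function, not part of this module):
#
#     import autograd.numpy as np
#     from autograd import grad
#
#     def tanh(x):
#         return (1.0 - np.exp(-2 * x)) / (1.0 + np.exp(-2 * x))
#
#     grad(tanh)(1.0)        # derivative of tanh at x = 1.0
#     grad(grad(tanh))(1.0)  # second derivative, since grad composes with itself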
|
|
|
|
| @unary_to_nary |
| def elementwise_grad(fun, x): |
| """ |
| Returns a function that computes the sum of each column of the Jacobian of |
| `fun`, in one pass. If the Jacobian is diagonal, then this is the diagonal |
| of the Jacobian. |
| """ |
| vjp, ans = _make_vjp(fun, x) |
| if vspace(ans).iscomplex: |
| raise TypeError("Elementwise_grad only applies to real-output functions.") |
| return vjp(vspace(ans).ones()) |
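
# A usage sketch for `elementwise_grad` (illustrative): for a function applied
# elementwise, this is the elementwise derivative, computed in a single pass:
#
#     from autograd import elementwise_grad
#     import autograd.numpy as np
#
#     x = np.linspace(0.0, 3.0, 5)
#     elementwise_grad(np.sin)(x)  # equals np.cos(x), up to floating point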
|
|
|
|
| @unary_to_nary |
| def deriv(fun, x): |
    """
    Returns a function which computes the derivative of `fun` with respect to
    positional argument number `argnum`, using forward-mode differentiation.
    This evaluates the Jacobian-vector product with an all-ones vector (the
    sum of each row of the Jacobian) in one pass; for scalar inputs or
    elementwise functions this is the ordinary elementwise derivative.
    """
    return _make_jvp(fun, x)(vspace(x).ones())[1]
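
# A usage sketch for `deriv` (illustrative): the forward-mode counterpart of
# `grad` for scalar-to-scalar functions:
#
#     from autograd import deriv
#     import autograd.numpy as np
#
#     deriv(np.tanh)(0.5)  # same value as grad(np.tanh)(0.5), via forward mode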
|
|
|
|
| @unary_to_nary |
| def jacobian(fun, x): |
| """ |
| Returns a function which computes the Jacobian of `fun` with respect to |
| positional argument number `argnum`, which must be a scalar or array. Unlike |
| `grad` it is not restricted to scalar-output functions, but also it cannot |
| take derivatives with respect to some argument types (like lists or dicts). |
| If the input to `fun` has shape (in1, in2, ...) and the output has shape |
| (out1, out2, ...) then the Jacobian has shape (out1, out2, ..., in1, in2, ...). |
| """ |
| vjp, ans = _make_vjp(fun, x) |
| ans_vspace = vspace(ans) |
| jacobian_shape = ans_vspace.shape + vspace(x).shape |
| grads = map(vjp, ans_vspace.standard_basis()) |
| return np.reshape(np.stack(grads), jacobian_shape) |
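
# A usage sketch for `jacobian` (illustrative): an input of shape (3,) and an
# output of shape (2,) give a Jacobian of shape (2, 3):
#
#     from autograd import jacobian
#     import autograd.numpy as np
#
#     def f(x):
#         return np.sin(x[:2])  # input shape (3,) -> output shape (2,)
#
#     jacobian(f)(np.array([1.0, 2.0, 3.0])).shape  # (2, 3)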
|
|
|
|
| @unary_to_nary |
| def holomorphic_grad(fun, x): |
    """
    Returns the gradient of `fun` under the assumption that `fun` is
    holomorphic, computed as the gradient of the real part of its output.
    """
    if not vspace(x).iscomplex:
| warnings.warn("Input to holomorphic_grad is not complex") |
| return grad(lambda x: np.real(fun(x)))(x) |
|
|
|
|
| def grad_named(fun, argname): |
| """Takes gradients with respect to a named argument. |
| Doesn't work on *args or **kwargs.""" |
| arg_index = _getargspec(fun).args.index(argname) |
| return grad(fun, arg_index) |
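
# A usage sketch for `grad_named` (illustrative; `loss` is a made-up function):
#
#     def loss(weights, inputs):
#         return np.sum(np.dot(inputs, weights) ** 2)
#
#     grad_named(loss, 'weights')  # equivalent to grad(loss, 0)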
|
|
|
|
| @unary_to_nary |
| def hessian(fun, x): |
| "Returns a function that computes the exact Hessian." |
| return jacobian(jacobian(fun))(x) |
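
# A usage sketch for `hessian` (illustrative): for a scalar-valued function of
# an input of shape (n,), the result has shape (n, n):
#
#     from autograd import hessian
#     import autograd.numpy as np
#
#     def f(x):
#         return np.sum(x ** 3)
#
#     hessian(f)(np.ones(3))  # diag(6, 6, 6), shape (3, 3)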
|
|
|
|
| @unary_to_nary |
| def make_hvp(fun, x): |
| """Builds a function for evaluating the Hessian-vector product at a point, |
| which may be useful when evaluating many Hessian-vector products at the same |
| point while caching the results of the forward pass.""" |
| return _make_vjp(grad(fun), x) |
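
# A usage sketch for `make_hvp` (illustrative; `f`, `x`, `v1`, `v2` stand for
# any scalar-valued function, point, and vectors): fix the point once, then
# apply the Hessian-vector product to many vectors:
#
#     from autograd import make_hvp
#
#     hvp, grad_x = make_hvp(f)(x)  # grad_x is the gradient of f at x
#     hvp(v1)                       # Hessian(f)(x) @ v1
#     hvp(v2)                       # Hessian(f)(x) @ v2, reusing the forward pass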
|
|
|
|
| def hessian_tensor_product(fun, argnum=0): |
| """Builds a function that returns the exact Hessian-tensor product. |
| The returned function has arguments (*args, tensor, **kwargs), and for |
| vectors takes roughly 4x as long to evaluate as the original function.""" |
| fun_grad = grad(fun, argnum) |
|
|
| def vector_dot_grad(*args, **kwargs): |
| args, vector = args[:-1], args[-1] |
| return np.tensordot(fun_grad(*args, **kwargs), vector, np.ndim(vector)) |
|
|
| return grad(vector_dot_grad, argnum) |
|
|
|
|
| hessian_vector_product = hessian_tensor_product |
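
# A usage sketch for `hessian_vector_product` (illustrative; `f`, `x`, `v`
# stand for any scalar-valued function, point, and vector):
#
#     hvp = hessian_vector_product(f)
#     hvp(x, v)  # == np.dot(hessian(f)(x), v), without materializing the Hessian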
|
|
|
|
| def tensor_jacobian_product(fun, argnum=0): |
| """Builds a function that returns the exact tensor-Jacobian product, that |
| is the Jacobian matrix left-multiplied by tensor. The returned function |
| has arguments (*args, tensor, **kwargs).""" |
|
|
| def vector_dot_fun(*args, **kwargs): |
| args, vector = args[:-1], args[-1] |
| return np.tensordot(vector, fun(*args, **kwargs), axes=np.ndim(vector)) |
|
|
| return jacobian(vector_dot_fun, argnum) |
|
|
|
|
| vector_jacobian_product = tensor_jacobian_product |
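
# A usage sketch for `vector_jacobian_product` (illustrative; `f`, `x`, `v`
# are any function, point, and output-shaped vector):
#
#     vjp_fun = vector_jacobian_product(f)
#     vjp_fun(x, v)  # == np.tensordot(v, jacobian(f)(x), np.ndim(v))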
|
|
|
|
| @unary_to_nary |
| def make_jvp_reversemode(fun, x): |
| """Builds a function for evaluating the Jacobian-vector product at a |
| point. Roughly 1.5x more FLOPs than forward-mode, plus memory requirements |
| that scale with the number of primitives applied in the evaluation of f, as |
| well as other overheads. See j-towns.github.io/2017/06/12/A-new-trick.html.""" |
| vjp, y = _make_vjp(fun, x) |
| vjp_vjp, _ = _make_vjp(vjp, vspace(y).zeros()) |
| return vjp_vjp |
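
# A usage sketch for `make_jvp_reversemode` (illustrative; `f`, `x`, `v` are
# any function, point, and input-shaped vector):
#
#     jvp = make_jvp_reversemode(f)(x)  # fix the point x
#     jvp(v)                            # Jacobian(f)(x) @ v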
|
|
|
|
| def make_ggnvp(f, g=lambda x: 1.0 / 2 * np.sum(x**2, axis=-1), f_argnum=0): |
| """Builds a function for evaluating generalized-Gauss-Newton-vector products |
| at a point. Slightly more expensive than mixed-mode.""" |
|
|
| @unary_to_nary |
| def _make_ggnvp(f, x): |
| f_vjp, f_x = _make_vjp(f, x) |
| g_hvp, grad_g_x = _make_vjp(grad(g), f_x) |
| f_jvp, _ = _make_vjp(f_vjp, vspace(grad_g_x).zeros()) |
|
|
| def ggnvp(v): |
| return f_vjp(g_hvp(f_jvp(v))) |
|
|
| return ggnvp |
|
|
| return _make_ggnvp(f, f_argnum) |
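
# A usage sketch for `make_ggnvp` (illustrative; `f`, `x`, `v` are any
# function, point, and input-shaped vector). With the default `g`, this is the
# classical Gauss-Newton-vector product J^T J v, for J the Jacobian of f at x:
#
#     ggnvp = make_ggnvp(f)(x)
#     ggnvp(v)  # == J^T @ (H_g @ (J @ v)), where H_g is the Hessian of g at f(x)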
|
|
|
|
| @unary_to_nary |
| def value_and_grad(fun, x): |
| """Returns a function that returns both value and gradient. Suitable for use |
| in scipy.optimize""" |
| vjp, ans = _make_vjp(fun, x) |
    if vspace(ans).size != 1:
| raise TypeError( |
| "value_and_grad only applies to real scalar-output " |
| "functions. Try jacobian, elementwise_grad or " |
| "holomorphic_grad." |
| ) |
| return ans, vjp(vspace(ans).ones()) |
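
# A usage sketch for `value_and_grad` (illustrative; `loss` and `x0` are a
# scalar-valued objective and an initial point):
#
#     from autograd import value_and_grad
#     from scipy.optimize import minimize
#
#     result = minimize(value_and_grad(loss), x0, jac=True)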
|
|
|
|
| @unary_to_nary |
| def grad_and_aux(fun, x): |
| """Builds a function that returns the gradient of the first output and the |
| (unmodified) second output of a function that returns two outputs.""" |
| vjp, (ans, aux) = _make_vjp(lambda x: atuple(fun(x)), x) |
| return vjp((vspace(ans).ones(), vspace(aux).zeros())), aux |
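
# A usage sketch for `grad_and_aux` (illustrative; the wrapped function must
# return a pair, here a loss and whatever side information it wants to keep):
#
#     def loss_and_predictions(w, inputs):
#         predictions = np.dot(inputs, w)
#         return np.sum(predictions ** 2), predictions
#
#     g, predictions = grad_and_aux(loss_and_predictions)(w, inputs)
#     # g is the gradient of the first output wrt w; predictions is untouched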
|
|
|
|
def multigrad_dict(fun):
    """Takes gradients wrt all arguments simultaneously,
    returns a dict mapping 'argname' to 'gradval'."""
|
|
| import funcsigs |
|
|
| sig = funcsigs.signature(fun) |
|
|
| def select(preds, lst): |
| idx = lambda item: next((i for i, pred in enumerate(preds) if pred(item)), len(preds)) |
| results = [[] for _ in preds] + [[]] |
| for item in lst: |
| results[idx(item)].append(item) |
| return results |
|
|
| is_var_pos = lambda name: sig.parameters[name].kind == sig.parameters[name].VAR_POSITIONAL |
| is_var_kwd = lambda name: sig.parameters[name].kind == sig.parameters[name].VAR_KEYWORD |
| var_pos, var_kwd, argnames = select([is_var_pos, is_var_kwd], sig.parameters) |
|
|
| todict = lambda dct: {key: dct[key] for key in dct} |
|
|
| def apply_defaults(arguments): |
| defaults = { |
| name: param.default for name, param in sig.parameters.items() if param.default is not param.empty |
| } |
| return OrderedDict( |
| (name, arguments[name] if name in arguments else defaults[name]) for name in sig.parameters |
| ) |
|
|
| def gradfun(*args, **kwargs): |
| bindings = sig.bind(*args, **kwargs) |
|
|
| args = lambda dct: tuple(dct[var_pos[0]]) if var_pos else () |
| kwargs = lambda dct: todict(dct[var_kwd[0]]) if var_kwd else {} |
| others = lambda dct: tuple(dct[argname] for argname in argnames if argname not in var_kwd + var_pos) |
|
|
| newfun = lambda dct: fun(*(others(dct) + args(dct)), **kwargs(dct)) |
|
|
| argdict = apply_defaults(bindings.arguments) |
| grad_dict = grad(newfun)(dict(argdict)) |
| return OrderedDict((argname, grad_dict[argname]) for argname in argdict) |
|
|
| return gradfun |
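
# A usage sketch for `multigrad_dict` (illustrative; requires the `funcsigs`
# package to be installed):
#
#     def f(x, y):
#         return x * y ** 2
#
#     multigrad_dict(f)(3.0, 2.0)  # OrderedDict([('x', 4.0), ('y', 12.0)])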
|
|
|
|
| def checkpoint(fun): |
| """Returns a checkpointed version of `fun`, where intermediate values |
| computed during the forward pass of `fun` are discarded and then recomputed |
| for the backward pass. Useful to save memory, effectively trading off time |
| and memory. See e.g. arxiv.org/abs/1604.06174. |
| """ |
|
|
| def wrapped_grad(argnum, ans, args, kwargs): |
| return make_vjp(fun, argnum)(*args, **kwargs)[0] |
|
|
| wrapped = primitive(fun) |
| defvjp_argnum(wrapped, wrapped_grad) |
| return wrapped |
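
# A usage sketch for `checkpoint` (illustrative; `expensive_step` stands for
# any memory-hungry subcomputation):
#
#     step = checkpoint(expensive_step)
#
#     def loss(x):
#         for _ in range(100):
#             x = step(x)
#         return np.sum(x)
#
#     # Backprop through `loss` now stores only the inputs to each step,
#     # recomputing each step's internals on the backward pass.
#     grad(loss)(x0)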
|
|