| # Copyright © 2023 Apple Inc. | |
| from typing import Callable | |
| import mlx.core as mx | |
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    ``fn`` is called with exactly the arguments given to the returned
    callable; the model is not passed to it, so ``fn`` has to reach ``model``
    on its own (for example by closing over it). Every call of the returned
    callable invokes ``model.update`` with the parameters being
    differentiated before ``fn`` runs.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
            gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``. It carries the metadata of ``fn``
        (name, docstring, ``__wrapped__``).
    """
    # Function-scope import keeps this definition self-contained.
    from functools import wraps

    def inner_fn(params, *args, **kwargs):
        # ``params`` is the first positional argument, which is the one
        # ``mx.value_and_grad`` differentiates with respect to by default
        # (per the MLX documentation). Loading it into the model makes the
        # output of ``fn`` depend on it.
        model.update(params)
        return fn(*args, **kwargs)

    value_grad_fn = mx.value_and_grad(inner_fn)

    # Preserve the identity of ``fn`` on the wrapper; note that the wrapper
    # returns a ``(value, grad)`` pair rather than the bare output of ``fn``.
    @wraps(fn)
    def wrapped_value_grad_fn(*args, **kwargs):
        # The parameters are read at call time, so the current state of the
        # model is what gets differentiated.
        value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
        return value, grad

    return wrapped_value_grad_fn