|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Helper functions for exporting JAX models to Tensorflow SavedModels.""" |
|
|
|
|
|
from typing import Any, Callable, Sequence, Optional, Union |
|
|
|
|
|
from jax.experimental import jax2tf |
|
|
import tensorflow as tf |
|
|
import tree as dm_tree |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PyTree = Any |
|
|
|
|
|
|
|
|
def convert_and_save_model(
    jax_fn: Callable[[PyTree, PyTree], PyTree],
    params: PyTree,
    model_dir: str,
    *,
    input_signatures: Union[
        Sequence[tf.TensorSpec],
        Sequence[Sequence[tf.TensorSpec]],
        Sequence[dict[str, tf.TensorSpec]],
    ],
    polymorphic_shapes: Optional[
        Union[str, jax2tf.PolyShape, dict[str, str]]
    ] = None,
    with_gradient: bool = False,
    enable_xla: bool = True,
    compile_model: bool = True,
    saved_model_options: Optional[tf.saved_model.SaveOptions] = None,
    native_serialization: Optional[str | bool] = "default",
    native_serialization_platforms: Sequence[str] | None = ("cpu", "tpu")):
  """Converts a JAX function and saves a SavedModel.

  We assume that the JAX model consists of a prediction function and trained
  parameters, and the computation graph of the function is saved separately
  from the parameters. Saving the graph separately from the parameters reduces
  the size of the Tensorflow `GraphDef`, and enables finetuning of model
  parameters too.

  To use this function, a JAX model must be converted to a function of two
  arguments, the model parameters and the input.
  For a Scenic model, this corresponds to:
  ```
  params = train_state.optimizer.target
  flax_model = model.flax_model
  def _predict_fn(params, input_data):
    return flax_model.apply({'params': params}, input_data, train=False)
  ```

  Args:
    jax_fn: A JAX function taking two arguments, the parameters and the inputs.
      Both arguments may be (nested) tuples/lists/dictionaries of `np.ndarray`.
      It is necessary to be able to JIT-compile this function (i.e. run
      `jax.jit` on it).
    params: The parameters, to be used as first argument for `jax_fn`. These
      must be (nested) tuples/lists/dictionaries of `np.ndarray`, and will be
      saved as the variables of the SavedModel.
    model_dir: The directory where the model should be saved.
    input_signatures: The input signatures for the second argument of `jax_fn`
      (the input). A signature must be a `tensorflow.TensorSpec` instance, or a
      (nested) tuple/list/dictionary thereof with a structure matching the
      second argument of `jax_fn`. The first input_signature will be saved as
      the default serving signature. The additional signatures will be used
      only to ensure that the `jax_fn` is traced and converted to TF for the
      corresponding input shapes.
    polymorphic_shapes: If given then it will be used as the
      `polymorphic_shapes` argument to `jax2tf.convert` for the second
      parameter of `jax_fn`. In this case, a single `input_signatures` is
      supported, and should have `None` in the polymorphic dimensions. This is
      required, for example, to have models with dynamic batch sizes.
    with_gradient: Whether the SavedModel should support gradients. If `True`,
      then a custom gradient is saved. If `False`, then a
      `tf.raw_ops.PreventGradient` is saved to error if a gradient is
      attempted. (At the moment due to a bug in SavedModel, custom gradients
      are not supported.)
    enable_xla: Whether the jax2tf converter is allowed to use TF XLA ops. If
      `False`, the conversion tries harder to use purely TF ops and raises an
      exception if it is not possible.
    compile_model: Use TensorFlow jit_compiler on the SavedModel. This
      is needed if the SavedModel will be used for TensorFlow serving.
    saved_model_options: Options to pass to `savedmodel.save`. If the options
      already carry `function_aliases`, the "inference_func" alias is merged
      in rather than replacing them.
    native_serialization: Serialize the JAX function natively to
      StableHLO with compatibility guarantees. This makes it easier to have
      confidence that the code executed when calling this function from
      TensorFlow is exactly the same as JAX would run natively. See
      jax2tf.convert() for details.
    native_serialization_platforms: When the "native_serialization" flag is
      used, the platforms that it will be serialised to. Must be a tuple of
      strings, including a subset of: ['cpu', 'cuda', 'rocm', 'tpu'].
      'None', specifies the JAX default backend on the machine where the
      lowering is done.

  Raises:
    ValueError: If at least one input signature is not defined. However, if
      `polymorphic_shapes` is given, then only one input signature is
      supported.
  """
  if not input_signatures:
    raise ValueError("At least one input_signature must be given.")
  if polymorphic_shapes is not None and len(input_signatures) > 1:
    raise ValueError("For shape-polymorphic conversion a single "
                     "input_signature is supported.")

  # The first element (`None`) covers `params`, which are always fully known;
  # only the inputs may have polymorphic dimensions.
  tf_fn = jax2tf.convert(
      jax_fn,
      with_gradient=with_gradient,
      polymorphic_shapes=[None, polymorphic_shapes],
      enable_xla=enable_xla,
      native_serialization=native_serialization,
      native_serialization_platforms=native_serialization_platforms)

  def get_tf_variable(path, param):
    # Wraps one leaf of `params` into a `tf.Variable` so it is saved as a
    # SavedModel variable. Path entries from `dm_tree` are ints for
    # list/tuple nodes, so stringify each entry before joining into a name.
    return tf.Variable(
        param, trainable=with_gradient, name="/".join(str(p) for p in path))

  param_vars = dm_tree.map_structure_with_path(get_tf_variable, params)
  # Close over the variables so the graph references them instead of
  # embedding the parameter values as constants in the GraphDef.
  tf_graph = tf.function(
      lambda inputs: tf_fn(param_vars, inputs),
      autograph=False,
      jit_compile=compile_model)

  # The first signature becomes the default serving signature.
  signatures = {
      tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          tf_graph.get_concrete_function(input_signatures[0])
  }
  # Trace the remaining signatures so those shapes are also baked into the
  # SavedModel; they are not registered as named serving signatures.
  for input_signature in input_signatures[1:]:
    tf_graph.get_concrete_function(input_signature)
  wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)

  if saved_model_options:
    # Merge instead of overwriting, so aliases supplied by the caller in
    # `saved_model_options.function_aliases` are preserved.
    function_aliases = dict(saved_model_options.function_aliases or {})
    function_aliases["inference_func"] = tf_graph
    saved_model_options.function_aliases = function_aliases
  else:
    saved_model_options = tf.saved_model.SaveOptions(
        function_aliases={"inference_func": tf_graph})

  if with_gradient:
    saved_model_options.experimental_custom_gradients = True

  tf.saved_model.save(
      wrapper, model_dir, signatures=signatures, options=saved_model_options)
|
|
|
|
|
|
|
|
class _ReusableSavedModelWrapper(tf.train.Checkpoint):
  """Bundles a traced TF function and its parameter variables for saving.

  Follows the Reusable SavedModel interface described at
  https://www.tensorflow.org/hub/reusable_saved_models.
  """

  def __init__(self, tf_graph: Callable[[PyTree], PyTree], param_vars: PyTree):
    """Initializes the wrapper.

    Args:
      tf_graph: A `tf.function` taking one argument (the inputs), which may be
        tuples/lists/dictionaries of `np.ndarray` or tensors. The function
        may have references to the `tf.Variables` in `param_vars`.
      param_vars: The parameters, as tuples/lists/dictionaries of
        `tf.Variable`, to be saved as the variables of the SavedModel.
    """
    super().__init__()
    flat_vars = tf.nest.flatten(param_vars)
    self.variables = flat_vars
    self.trainable_variables = [
        variable for variable in flat_vars if variable.trainable
    ]
    self.__call__ = tf_graph
|
|
|