# Copyright 2024 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helper functions for exporting JAX models to Tensorflow SavedModels."""

from typing import Any, Callable, Sequence, Optional, Union

from jax.experimental import jax2tf
import tensorflow as tf
import tree as dm_tree

# JAX team is working on type annotation for pytree:
# https://github.com/jax-ml/jax/issues/1555
# A PyTree is a nested dictionary where the leaves are `jnp.ndarray`.
# TODO(aarnab): Fix type annotation once ready.
PyTree = Any


def convert_and_save_model(
    jax_fn: Callable[[PyTree, PyTree], PyTree],
    params: PyTree,
    model_dir: str,
    *,
    input_signatures: Union[
        Sequence[tf.TensorSpec],
        Sequence[Sequence[tf.TensorSpec]],
        Sequence[dict[str, tf.TensorSpec]],
    ],
    polymorphic_shapes: Optional[
        Union[str, jax2tf.PolyShape, dict[str, str]]
    ] = None,
    with_gradient: bool = False,
    enable_xla: bool = True,
    compile_model: bool = True,
    saved_model_options: Optional[tf.saved_model.SaveOptions] = None,
    native_serialization: Optional[str | bool] = "default",
    native_serialization_platforms: Sequence[str] | None = ("cpu", "tpu")):
  """Converts a JAX function and saves a SavedModel.

  We assume that the JAX model consists of a prediction function and trained
  parameters, and that the computation graph of the function is saved
  separately from the parameters. Saving the graph separately from the
  parameters reduces the size of the TensorFlow `GraphDef` and also makes it
  possible to fine-tune the model parameters.

  To use this function, a JAX model must be converted to a function of two
  arguments, the model parameters and the input.
  For a Scenic model, this corresponds to:
  ```
  params = train_state.optimizer.target
  flax_model = model.flax_model
  def _predict_fn(params, input_data):
    return flax_model.apply({'params': params}, input_data, train=False)
  ```
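
  The conversion itself is performed by `jax2tf.convert`. A minimal
  self-contained sketch of a shape-polymorphic conversion, as done by this
  function (the toy `predict_fn`, parameters, and shapes below are
  illustrative only, not part of this API):
  ```
  import jax.numpy as jnp
  from jax.experimental import jax2tf
  import tensorflow as tf

  def predict_fn(params, inputs):
    return jnp.dot(inputs, params['w']) + params['b']

  params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
  # 'batch' marks the leading dimension as polymorphic (dynamic at serving
  # time); `None` means the parameter shapes are taken from the actual args.
  tf_fn = jax2tf.convert(predict_fn, polymorphic_shapes=[None, '(batch, 3)'])
  out = tf_fn(params, jnp.ones((5, 3)))  # out.shape == (5, 2)
  ```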

  Args:
    jax_fn: A JAX function taking two arguments, the parameters and the inputs.
      Both arguments may be (nested) tuples/lists/dictionaries of `np.ndarray`.
      It must be possible to JIT-compile this function (i.e., to run `jax.jit`
      on it).
    params: The parameters, to be used as first argument for `jax_fn`. These
      must be (nested) tuples/lists/dictionaries of `np.ndarray`, and will be
      saved as the variables of the SavedModel.
    model_dir: The directory where the model should be saved.
    input_signatures: The input signatures for the second argument of `jax_fn`
      (the input). A signature must be a `tensorflow.TensorSpec` instance, or a
      (nested) tuple/list/dictionary thereof with a structure matching the
      second argument of `jax_fn`. The first input_signature will be saved as
      the default serving signature. The additional signatures will be used
      only to ensure that the `jax_fn` is traced and converted to TF for the
      corresponding input shapes.
    polymorphic_shapes: If given, used as the `polymorphic_shapes` argument to
      `jax2tf.convert` for the second parameter of `jax_fn`. In this case only
      a single input signature is supported, and it should have `None` in the
      polymorphic dimensions. This is required, for example, for models with
      dynamic batch sizes.
    with_gradient: Whether the SavedModel should support gradients. If `True`,
      then a custom gradient is saved. If `False`, then a
      `tf.raw_ops.PreventGradient` is saved to error if a gradient is attempted.
      (At the moment due to a bug in SavedModel, custom gradients are not
      supported.)
    enable_xla: Whether the jax2tf converter is allowed to use TF XLA ops. If
      `False`, the conversion tries harder to use purely TF ops and raises an
      exception if it is not possible.
    compile_model: Whether to compile the TensorFlow function with
      `jit_compile=True`. This is needed if the SavedModel will be used for
      TensorFlow Serving.
    saved_model_options: Options to pass to `tf.saved_model.save`.
    native_serialization: Serialize the JAX function natively to
      StableHLO with compatibility guarantees. This makes it easier to have
      confidence that the code executed when calling this function from
      TensorFlow is exactly the same as JAX would run natively. See
      jax2tf.convert() for details.
    native_serialization_platforms: When native serialization is used, the
      platforms to serialize for. Must be a tuple of strings drawn from
      ['cpu', 'cuda', 'rocm', 'tpu']. `None` specifies the JAX default backend
      on the machine where the lowering is done.

  Raises:
    ValueError: If no input signature is given, or if more than one input
      signature is given when `polymorphic_shapes` is specified.
  """
  if not input_signatures:
    raise ValueError("At least one input_signature must be given.")
  if polymorphic_shapes is not None and len(input_signatures) > 1:
    raise ValueError("For shape-polymorphic conversion a single "
                     "input_signature is supported.")
  tf_fn = jax2tf.convert(
      jax_fn,
      with_gradient=with_gradient,
      polymorphic_shapes=[None, polymorphic_shapes],
      enable_xla=enable_xla,
      native_serialization=native_serialization,
      native_serialization_platforms=native_serialization_platforms)

  def get_tf_variable(path, param):
    return tf.Variable(param, trainable=with_gradient, name="/".join(path))

  # Due to a bug in SavedModel it is not possible to use `tf.GradientTape` on
  # a function converted with jax2tf and loaded from a SavedModel. Unless
  # `with_gradient` is set, the variables are therefore marked non-trainable
  # to ensure that users of the SavedModel will not try to fine-tune them.
  param_vars = dm_tree.map_structure_with_path(get_tf_variable, params)
  tf_graph = tf.function(
      lambda inputs: tf_fn(param_vars, inputs),
      autograph=False,
      jit_compile=compile_model)

  # This signature is needed for TensorFlow Serving use.
  signatures = {
      tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          tf_graph.get_concrete_function(input_signatures[0])
  }

  for input_signature in input_signatures[1:]:
    # If there are more signatures, trace and cache a TF function for each one.
    tf_graph.get_concrete_function(input_signature)
  wrapper = _ReusableSavedModelWrapper(tf_graph, param_vars)

  if saved_model_options:
    saved_model_options.function_aliases = {"inference_func": tf_graph}
  else:
    saved_model_options = tf.saved_model.SaveOptions(
        function_aliases={"inference_func": tf_graph}
    )

  if with_gradient:
    saved_model_options.experimental_custom_gradients = True

  tf.saved_model.save(
      wrapper, model_dir, signatures=signatures, options=saved_model_options
  )


class _ReusableSavedModelWrapper(tf.train.Checkpoint):
  """Wraps a function and its parameters for saving to a SavedModel.

  Implements the interface described at
  https://www.tensorflow.org/hub/reusable_saved_models.
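
  After saving, the wrapped function can be restored with
  `tf.saved_model.load` and invoked directly. A minimal self-contained sketch
  of the same pattern (the toy variable, function, and temporary directory
  are illustrative only):
  ```
  import tempfile
  import tensorflow as tf

  wrapper = tf.train.Checkpoint()
  wrapper.variables = [tf.Variable(2.0)]
  wrapper.__call__ = tf.function(
      lambda x: x * wrapper.variables[0],
      input_signature=[tf.TensorSpec([], tf.float32)])
  model_dir = tempfile.mkdtemp()
  tf.saved_model.save(wrapper, model_dir)
  restored = tf.saved_model.load(model_dir)
  # `restored(...)` dispatches to the saved `__call__` function.
  ```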
  """

  def __init__(self, tf_graph: Callable[[PyTree], PyTree], param_vars: PyTree):
    """Constructor.

    Args:
      tf_graph: A `tf.function` taking one argument (the inputs), which can
        be tuples/lists/dictionaries of `np.ndarray` or tensors. The function
        may have references to the `tf.Variable`s in `param_vars`.
      param_vars: The parameters, as tuples/lists/dictionaries of
        `tf.Variable`, to be saved as the variables of the SavedModel.
    """
    super().__init__()
    self.variables = tf.nest.flatten(param_vars)
    self.trainable_variables = [v for v in self.variables if v.trainable]
    self.__call__ = tf_graph