|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Base class for model export."""
|
|
|
| import abc
|
| import functools
|
| import time
|
| from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
|
|
|
| from absl import logging
|
| import tensorflow as tf, tf_keras
|
|
|
# Upper bound on how many timestamped directory names
# `get_timestamped_export_dir` tries before giving up; it waits one second
# between attempts.
MAX_DIRECTORY_CREATION_ATTEMPTS: int = 10
|
|
|
|
|
class ExportModule(tf.Module, metaclass=abc.ABCMeta):
  """Base Export Module.

  Subclasses implement `serve()` and `get_inference_signatures()`. The
  module-level `export()` function calls `get_inference_signatures()` to build
  the signatures passed to `tf.saved_model.save`.
  """

  def __init__(self,
               params,
               model: Union[tf.Module, tf_keras.Model],
               inference_step: Optional[Callable[..., Any]] = None,
               *,
               preprocessor: Optional[Callable[..., Any]] = None,
               postprocessor: Optional[Callable[..., Any]] = None):
    """Instantiates an ExportModule.

    Examples:

    `inference_step` must be a function that accepts a keyword argument named
    `model`; the model is bound to it with
    `functools.partial(inference_step, model=model)`.
    ```
    def _inference_step(inputs, model=None):
      return model(inputs, training=False)

    module = ExportModule(params, model, inference_step=_inference_step)
    ```

    `preprocessor` and `postprocessor` could be either functions or
    `tf.Module`. This base class only stores them; how they are used is up to
    the implementation of the `serve()` method.

    Args:
      params: A dataclass for parameters to the module.
      model: A model instance which contains weights and forward computation.
      inference_step: An optional callable to forward-pass the model. If
        given, it is wrapped so that `model` is always passed as the `model`
        keyword argument. If not given, a default is chosen: `model.call` for a
        `tf_keras.Model`, otherwise `model.__call__` with `training=False`
        bound.
      preprocessor: An optional callable to preprocess the inputs.
      postprocessor: An optional callable to postprocess the model outputs.
    """
    super().__init__(name=None)
    self.model = model
    self.params = params

    if inference_step is not None:
      self.inference_step = functools.partial(inference_step, model=self.model)
    elif isinstance(model, tf_keras.Model):
      # Use `call` rather than `__call__` so the Keras `Model.__call__` wrapper
      # is bypassed. No `training` argument is passed here, so a model whose
      # `call` defaults to `training=True` needs a custom `inference_step`.
      self.inference_step = self.model.call
    else:
      # Assumes the module's `__call__` accepts a `training` keyword argument.
      self.inference_step = functools.partial(
          self.model.__call__, training=False)
    self.preprocessor = preprocessor
    self.postprocessor = postprocessor

  @abc.abstractmethod
  def serve(self) -> Mapping[Text, tf.Tensor]:
    """The bare inference function which should run on all devices.

    Tensors are expected to be passed in through keyword arguments. Returns a
    dictionary of tensors whose keys will be used inside the SignatureDef.
    """

  @abc.abstractmethod
  def get_inference_signatures(
      self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
    """Gets defined function signatures.

    Args:
      function_keys: A dictionary mapping a function key to the signature key
        it should be exported under.

    Returns:
      A mapping suitable for the `signatures` argument of
      `tf.saved_model.save`.
    """
|
|
|
|
|
def export(export_module: ExportModule,
           function_keys: Union[List[Text], Dict[Text, Text]],
           export_savedmodel_dir: Text,
           checkpoint_path: Optional[Text] = None,
           timestamped: bool = True,
           save_options: Optional[tf.saved_model.SaveOptions] = None,
           checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
  """Exports to SavedModel format.

  Args:
    export_module: an ExportModule with the keras Model and serving
      tf.functions.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. A list must contain exactly one key, which is exported under
      `tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY`. If a dictionary is
      provided, its values will be used as signature keys.
    export_savedmodel_dir: Output saved model directory. When `timestamped` is
      True, this is the base directory under which a new timestamped
      subdirectory is chosen.
    checkpoint_path: Object-based checkpoint path or directory. For a
      directory, the latest checkpoint in it is used. If no checkpoint is
      resolved, the model is exported with its current weights.
    timestamped: Whether to export the savedmodel to a timestamped directory.
    save_options: `SaveOptions` for `tf.saved_model.save`.
    checkpoint: An optional tf.train.Checkpoint. If provided, the export module
      will use it to read the weights. It is only used when a checkpoint is
      resolved from `checkpoint_path`.

  Returns:
    The savedmodel directory path.

  Raises:
    ValueError: if `function_keys` is a list that does not contain exactly one
      element.
  """
  ckpt_dir_or_file = checkpoint_path
  if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      # An empty checkpoint directory would otherwise go unnoticed and the
      # model would be exported without restoring any weights.
      logging.warning(
          'No checkpoint found in directory %s; exporting the model without '
          'restoring weights.', checkpoint_path)
  if ckpt_dir_or_file:
    if checkpoint is None:
      checkpoint = tf.train.Checkpoint(model=export_module.model)
    # Fail if objects already tracked by `checkpoint` are missing from the
    # checkpoint, but silence warnings about unused checkpoint values.
    checkpoint.read(
        ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()

  if isinstance(function_keys, list):
    if len(function_keys) != 1:
      raise ValueError(
          'If the function_keys is a list, it must contain a single element. %s'
          % function_keys)
    function_keys = {
        function_keys[0]: tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    }

  signatures = export_module.get_inference_signatures(function_keys)
  if timestamped:
    # get_timestamped_export_dir returns bytes; decode so a str is returned.
    export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
        'utf-8')
  else:
    export_dir = export_savedmodel_dir
  tf.saved_model.save(
      export_module, export_dir, signatures=signatures, options=save_options)
  return export_dir
|
|
|
|
|
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  The subdirectory is named after the current Unix time in whole seconds. If a
  directory with that name already exists, the function waits one second and
  tries again, up to `MAX_DIRECTORY_CREATION_ATTEMPTS` attempts in total.

  Args:
    export_dir_base: A string (`str` or `bytes`) containing a directory to
      write the exported graph and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet),
    as `bytes`.

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  base_dir = tf.compat.as_bytes(export_dir_base)
  for attempt in range(1, MAX_DIRECTORY_CREATION_ATTEMPTS + 1):
    timestamp = int(time.time())
    result_dir = tf.io.gfile.join(base_dir,
                                  tf.compat.as_bytes(str(timestamp)))
    # Checking for existence here and creating the directory later is not
    # atomic, so a concurrent exporter could still pick the same name; the
    # one-second retry only makes such collisions unlikely.
    if not tf.io.gfile.exists(result_dir):
      return result_dir
    if attempt == MAX_DIRECTORY_CREATION_ATTEMPTS:
      # No retries left: raise immediately instead of sleeping first.
      break
    logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
                    str(result_dir), attempt, MAX_DIRECTORY_CREATION_ATTEMPTS)
    time.sleep(1)
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
|
|
|