Spaces:
Sleeping
Sleeping
| # Copyright 2023 The Orbit Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Some layered modules/functions to help users write custom training loops.""" | |
| import inspect | |
| import tensorflow as tf, tf_keras | |
def create_global_step() -> tf.Variable:
  """Builds the counter variable used to track the global training step.

  Some `AbstractTrainer` subclasses apply several parameter updates within a
  single `Controller` "step", or alternate between optimizers from step to
  step. For those, `optimizer.iterations` counts parameter updates rather than
  iterations of the `Controller`'s training loop, so it generally cannot serve
  as the step counter. Such trainers should own a dedicated variable and call
  `step.assign_add(1)` at the end of each step.

  Returns:
    A scalar, non-trainable `tf.Variable` of dtype `tf.int64`, initialized to
    0. When synchronized across replicas in a distributed setting, only the
    first replica's value is kept.
  """
  step_counter = tf.Variable(
      initial_value=0,
      trainable=False,
      name="global_step",
      dtype=tf.int64,
      # Every replica advances the counter in lockstep, so keeping just the
      # first replica's value is sufficient when aggregating.
      aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
  )
  return step_counter
def make_distributed_dataset(strategy, dataset_or_fn, *args, **kwargs):
  """A utility function to help create a `tf.distribute.DistributedDataset`.

  Args:
    strategy: An instance of `tf.distribute.Strategy`. If `None`, the strategy
      returned by `tf.distribute.get_strategy()` is used.
    dataset_or_fn: An instance of `tf.data.Dataset`, or a "dataset function"
      returning a `tf.data.Dataset`. If it is a function, it may optionally have
      an argument named `input_context` (either a regular or a keyword-only
      parameter) which will be passed a `tf.distribute.InputContext` instance.
    *args: Any positional arguments to pass through to `dataset_or_fn`. Unused
      when `dataset_or_fn` is already a `tf.data.Dataset`.
    **kwargs: Any keyword arguments to pass through to `dataset_or_fn`, except
      that the `input_options` keyword is used to specify a
      `tf.distribute.InputOptions` for making the distributed dataset.

  Returns:
    A distributed Dataset.

  Raises:
    ValueError: If `dataset_or_fn` is neither a `tf.data.Dataset` nor callable.
  """
  if strategy is None:
    strategy = tf.distribute.get_strategy()

  # `input_options` is consumed here and never forwarded to `dataset_or_fn`.
  input_options = kwargs.pop("input_options", None)

  if isinstance(dataset_or_fn, tf.data.Dataset):
    return strategy.experimental_distribute_dataset(dataset_or_fn,
                                                    input_options)

  if not callable(dataset_or_fn):
    raise ValueError("`dataset_or_fn` should be either callable or an instance "
                     "of `tf.data.Dataset`.")

  # Inspect the signature once, up front, rather than on every invocation of
  # the wrapper. Keyword-only parameters are reported separately from `args`,
  # so both lists must be checked to detect an `input_context` parameter.
  argspec = inspect.getfullargspec(dataset_or_fn)
  accepts_input_context = (
      "input_context" in argspec.args or
      "input_context" in argspec.kwonlyargs)

  def dataset_fn(input_context):
    """Wraps `dataset_or_fn` for strategy.distribute_datasets_from_function."""
    # Build per-call keyword arguments instead of mutating the captured
    # `kwargs`, so repeated invocations never share state.
    call_kwargs = dict(kwargs)
    # If `dataset_or_fn` has an argument named `input_context`, pass through
    # the given `input_context`. Otherwise `input_context` will be ignored.
    if accepts_input_context:
      call_kwargs["input_context"] = input_context
    return dataset_or_fn(*args, **call_kwargs)

  return strategy.distribute_datasets_from_function(dataset_fn, input_options)
def get_value(x):
  """Converts a TensorFlow value to NumPy, leaving other inputs untouched.

  Args:
    x: Any value; may be a `tf.Tensor` or `tf.Variable`.

  Returns:
    `x.numpy()` when `tf.is_tensor(x)` is true; otherwise `x` itself,
    unchanged.
  """
  return x.numpy() if tf.is_tensor(x) else x