Spaces:
Runtime error
Runtime error
| # Lint as: python3 | |
| # Copyright 2020 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Defines the base task abstraction.""" | |
| import abc | |
| import functools | |
| from typing import Any, Callable, Optional | |
| import six | |
| import tensorflow as tf | |
| from official.modeling.hyperparams import config_definitions as cfg | |
| from official.utils import registry | |
class Task(tf.Module):
  """A single-replica view of training procedure.

  Tasks provide artifacts for training/evaluation procedures, including
  loading/iterating over Datasets, initializing the model, calculating the loss
  and customized metrics with reduction.
  """

  # Special key under which the scalar loss is reported in the logs dict
  # returned by `train_step`/`validation_step`.
  loss = "loss"

  def __init__(self, params: cfg.TaskConfig):
    """Initializes the task.

    Args:
      params: the task configuration instance, which can be any of dataclass,
        ConfigDict, namedtuple, etc. that subclasses `cfg.TaskConfig`.
    """
    self._task_config = params

  def task_config(self) -> cfg.TaskConfig:
    """Returns the task configuration this task was constructed with."""
    return self._task_config

  def initialize(self, model: tf.keras.Model):
    """A callback function used as CheckpointManager's init_fn.

    This function will be called when no checkpoint is found for the model.
    If there is a checkpoint, the checkpoint will be loaded and this function
    will not be called. You can use this callback function to load a pretrained
    checkpoint, saved under a directory other than the model_dir.

    Args:
      model: The keras.Model built or used by this task.
    """
    pass

  def build_model(self) -> tf.keras.Model:
    """Creates the model architecture.

    The base implementation is a no-op (returns None); subclasses are
    expected to override it.

    Returns:
      A model instance.
    """

  def compile_model(self,
                    model: tf.keras.Model,
                    optimizer: tf.keras.optimizers.Optimizer,
                    loss=None,
                    train_step: Optional[Callable[..., Any]] = None,
                    validation_step: Optional[Callable[..., Any]] = None,
                    **kwargs) -> tf.keras.Model:
    """Compiles the model with objects created by the task.

    The method should not be used in any customized training implementation.

    Args:
      model: a keras.Model.
      optimizer: the keras optimizer.
      loss: a callable/list of losses.
      train_step: optional train step function defined by the task.
      validation_step: optional validation_step step function defined by the
        task.
      **kwargs: other kwargs consumed by keras.Model compile().

    Returns:
      a compiled keras.Model.

    Raises:
      ValueError: if neither or both of `loss` and `train_step` are provided.
    """
    # Exactly one of `loss` / `train_step` must be given: either Keras
    # computes the loss itself, or the task-defined step does.
    if bool(loss is None) == bool(train_step is None):
      raise ValueError("`loss` and `train_step` should be exclusive to "
                       "each other.")
    model.compile(optimizer=optimizer, loss=loss, **kwargs)

    if train_step:
      # Override Keras' default train/test steps so model.fit() runs the
      # task-defined step functions instead.
      model.train_step = functools.partial(
          train_step, model=model, optimizer=model.optimizer)
    if validation_step:
      model.test_step = functools.partial(validation_step, model=model)
    return model

  def build_inputs(self,
                   params: cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a dataset or a nested structure of dataset functions.

    Dataset functions define per-host datasets with the per-replica batch size.

    The base implementation is a no-op (returns None); subclasses are
    expected to override it.

    Args:
      params: hyperparams to create input pipelines.
      input_context: optional distribution input pipeline context.

    Returns:
      A nested structure of per-replica input functions.
    """

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    The base implementation ignores `labels` and `model_outputs` and only
    sums `aux_losses` (returning a zero scalar when there are none);
    subclasses override this to compute a real task loss.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    del model_outputs, labels

    if aux_losses is None:
      losses = [tf.constant(0.0, dtype=tf.float32)]
    else:
      losses = aux_losses
    total_loss = tf.add_n(losses)
    return total_loss

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation.

    The base implementation returns an empty list; subclasses override this
    to provide task-specific metric objects.
    """
    del training
    return []

  def process_metrics(self, metrics, labels, model_outputs):
    """Process and update metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Process and update compiled_metrics.

    Called when using the compile/fit API.

    Args:
      compiled_metrics: the compiled metrics (model.compiled_metrics).
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    compiled_metrics.update_state(labels, model_outputs)

  def train_step(self,
                 inputs,
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors, or a (features, labels) tuple.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      # Single-structure input: the same structure is passed to the model
      # and to the loss/metrics as "labels" (e.g. self-supervised setups).
      features, labels = inputs, inputs
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

      # For mixed precision, when a LossScaleOptimizer is used, the loss is
      # scaled to avoid numeric underflow.
      if isinstance(optimizer,
                    tf.keras.mixed_precision.experimental.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)
    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Undo the loss scaling on the gradients before they are applied.
    if isinstance(optimizer,
                  tf.keras.mixed_precision.experimental.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    # Report the unscaled per-replica loss.
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      # No explicit metrics passed: fall back to Keras compiled metrics.
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors, or a (features, labels) tuple.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, inputs, model: tf.keras.Model):
    """Performs the forward step (model call with training=False)."""
    return model(inputs, training=False)

  def aggregate_logs(self, state, step_logs):
    """Optional aggregation over logs returned from a validation step.

    The base implementation is a no-op; subclasses may accumulate `step_logs`
    into `state` and return the updated state.
    """
    pass

  def reduce_aggregated_logs(self, aggregated_logs):
    """Optional reduce of aggregated logs over validation steps.

    The base implementation returns an empty dict.
    """
    return {}
# Global mapping from TaskConfig subclasses to their registered Task factory.
_REGISTERED_TASK_CLS = {}


# TODO(b/158268740): Move these outside the base class file.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
  """Decorates a factory of Tasks for lookup by a subclass of TaskConfig.

  Typical usage:

  ```
  @dataclasses.dataclass
  class MyTaskConfig(TaskConfig):
    # Add fields here.
    pass

  @register_task_cls(MyTaskConfig)
  class MyTask(Task):
    # Inherits def __init__(self, task_config).
    pass

  my_task_config = MyTaskConfig()
  my_task = get_task(my_task_config)  # Returns MyTask(my_task_config).
  ```

  The decorated object does not have to be a class: any callable that builds
  a Task from a TaskConfig works, provided each config class is registered at
  most once.

  Args:
    task_config_cls: a subclass of TaskConfig (*not* an instance of
      TaskConfig). Each task_config_cls can only be used for a single
      registration.

  Returns:
    A callable for use as class decorator that registers the decorated class
    for creation from an instance of task_config_cls.
  """
  decorator = registry.register(_REGISTERED_TASK_CLS, task_config_cls)
  return decorator
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
  """Looks up the Task factory registered for `task_config_cls`."""
  return registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)