Spaces:
Runtime error
Runtime error
| # Copyright 2022 Google. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """TrainingTask encapsulates the state associated with model step.""" | |
| import time | |
| from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Tuple) | |
| from absl import logging | |
| from clu import metric_writers | |
| from flax import optim | |
| from flax import struct | |
| import jax | |
| import metrics_summary | |
| import numpy as np | |
@struct.dataclass
class TrainState:
  """Bundles the trainable parameters with other mutable model state.

  Declared as a flax struct.dataclass so instances can be constructed,
  `.replace()`d functionally, and traversed as a JAX pytree (required for
  the pmap-style replication used by run_step).  Without the decorator the
  class has only bare annotations and cannot be instantiated with fields.
  """
  optimizer: optim.Optimizer  # Trainable parameters.
  state: Any  # Other state, e.g. XL cache or memory.
# Opaque collection of PRNG keys; forwarded unchanged to the step function.
PRNGKeys = Any
# Metrics are a dictionary of named values produced by the step function.
Metrics = Dict[str, Any]
# Accumulates metrics over multiple steps; see the metrics_summary module.
MetricsSummary = metrics_summary.MetricsSummary
# A zero-argument callable that builds a fresh iterator over input batches.
Dataset = Callable[[], Iterator[Any]]
# Maps (train_state, input_batch, prng_keys) -> (new_train_state, metrics).
StepFunction = Callable[[TrainState, Any, Any], Tuple[TrainState, Metrics]]
# Optional: converts an input batch to a human-readable string for logging.
PrettyPrintInputFunction = Optional[Callable[[Any], str]]
# Optional: post-processes (summary, mode) before the summary is written.
ProcessSummariesFunction = Optional[Callable[[Any, str], Any]]
# Optional: produces extra summaries for (mode, step) at flush time.
ExtraSummariesFunction = Optional[Callable[[str, int], Mapping[str, Any]]]
def should_run(step: int, every_steps: int) -> bool:
  """Returns True when a periodic action is due at the given step.

  An action is due when the interval is positive and `step` is a nonzero
  multiple of it; step 0 (and any non-positive interval) never triggers.
  """
  if step <= 0 or every_steps <= 0:
    return False
  return step % every_steps == 0
class TrainingTask:
  """A TrainingTask encapsulates the state associated with a training task.

  Examples of tasks include training steps, test or validation runs,
  or inference (generation). State includes the input pipeline, and
  summary information that is averaged over multiple steps.
  """

  def __init__(
      self,
      *,  # Pass arguments by keyword only.
      mode: str,
      dataset: Dataset,
      step_function: StepFunction,
      prng_keys: PRNGKeys,
      summary: MetricsSummary,
      extra_summary: MetricsSummary,
      summary_writer: metric_writers.MetricWriter,
      summary_prefix: str = "",
      # --- Options from TrainingLoop ---
      replicate_mode: bool = True,
      print_input_every_steps: int = 0,
      pretty_print_input_function: PrettyPrintInputFunction = None,
      process_summaries_function: ProcessSummariesFunction = None,
      # NOTE: the alias is already Optional, so no extra Optional[...] here.
      extra_summaries_function: ExtraSummariesFunction = None):
    """Creates a TrainingTask.

    Args:
      mode: Name of the task, e.g. "train" or "test"; used in log messages.
      dataset: Zero-arg callable returning a fresh iterator over inputs.
        May be None, in which case get_next_input() returns None.
      step_function: Runs the model for one step; see StepFunction.
      prng_keys: Random keys, forwarded unchanged to step_function.
      summary: Accumulates per-step metrics between flushes.
      extra_summary: Holds summaries from extra_summaries_function.
      summary_writer: Destination for summaries. May be None, in which case
        flush() discards accumulated summaries instead of writing them.
      summary_prefix: Prefix for summary names written by flush().
      replicate_mode: If True, inputs are split across local devices before
        the step function, and metrics are read from the first replica.
      print_input_every_steps: If > 0, pretty-print the input to the log
        and summary every N steps (requires pretty_print_input_function).
      pretty_print_input_function: Converts an input batch to a string.
      process_summaries_function: Post-processes the summary before writing.
      extra_summaries_function: Produces extra summaries for (mode, step).
    """
    # Local state.
    self.mode = mode
    self.dataset = dataset
    self.step_function = step_function
    self.prng_keys = prng_keys
    self.summary = summary
    self.extra_summary = extra_summary
    self.summary_writer = summary_writer
    self.summary_prefix = summary_prefix

    # Options carried over from TrainingLoop.
    self.replicate_mode = replicate_mode
    self.print_input_every_steps = print_input_every_steps
    self.pretty_print_input_fn = pretty_print_input_function
    self.process_summaries_fn = process_summaries_function
    self.extra_summaries_fn = extra_summaries_function

    # Local state.
    if self.dataset is not None:
      self.ds_iterator = self.dataset()
    self.epoch = 0

  def _get_metrics(self, device_metrics: Metrics) -> Metrics:
    """Reads a dictionary of metrics from device into host numpy arrays."""
    if self.replicate_mode:
      # x[0] gets the metric from device 0 -- the first replica.
      # We assume that merge_replicated_metrics has already combined the
      # metrics from multiple devices.
      # Use jax.tree_util.tree_map: the top-level jax.tree_map alias was
      # deprecated and removed in newer JAX releases.
      device_metrics = jax.tree_util.tree_map(lambda x: x[0], device_metrics)
    metrics_np = jax.device_get(device_metrics)  # Get numpy arrays.
    return metrics_np

  def get_next_input(self) -> Any:
    """Grabs the next input from the data pipeline.

    Restarts the iterator (and increments self.epoch) when the current
    epoch is exhausted.  Returns None if the task has no dataset.
    """
    if self.dataset is None:
      logging.warning("No dataset for mode %s", self.mode)
      return None
    try:
      x = next(self.ds_iterator)
    except StopIteration:
      logging.info("End of epoch %d for mode %s.", self.epoch, self.mode)
      self.ds_iterator = self.dataset()
      x = next(self.ds_iterator)
      self.epoch += 1
    return x

  def run_step(self, tstate: TrainState, x: Any,
               step: int, sub_step: int = 0) -> Tuple[TrainState, Metrics]:
    """Runs the model for a single step.

    Args:
      tstate: The current model state.
      x: The input for the model -- from get_next_input.
      step: The training step number.
      sub_step: For tasks that run multiple iterations within a step.
        E.g. A test cycle will call run_step multiple times to cover the
        test set. The step counter will not increment, but sub_step will.

    Returns:
      An updated model state.
    """
    start_time = time.perf_counter()

    # Split a batch of inputs among local replicas.
    if self.replicate_mode:
      x = split_batch_dimension(x, jax.local_device_count())

    # Pretty-print the input to the summary and log file every so often.
    if (sub_step == 0 and self.pretty_print_input_fn is not None and
        should_run(step, self.print_input_every_steps)):
      # In replicate mode, only print the copy that goes to the first
      # replica.
      x_first = (jax.tree_util.tree_map(lambda x: x[0], x)
                 if self.replicate_mode else x)
      x_strs = self.pretty_print_input_fn(x_first)
      logging.info("[%d] Input (%s) = %s", step, self.mode, x_strs)
      self.summary.add_text({"input": x_strs})

    # Run the step function on the input.
    with jax.profiler.StepTraceAnnotation(self.mode, step_num=step):
      (tstate, metrics) = self.step_function(tstate, x, self.prng_keys)

    # Read metrics from device.
    metrics_np = self._get_metrics(metrics)
    end_time = time.perf_counter()
    metrics_np["step_time"] = end_time - start_time
    if "epoch" not in metrics_np.keys():
      metrics_np["epoch"] = self.epoch

    # Add metrics to the current summary.
    self.summary.add(metrics_np)
    return (tstate, metrics_np)

  def flush(self, step: int):
    """Flushes accumulated metric summaries to disk."""
    if self.summary_writer is None:
      self.summary.clear()  # Clear summary if we can't write it.
      return
    if self.summary.empty():
      return

    # Do post-processing of the summaries.
    if self.process_summaries_fn is not None:
      self.summary = self.process_summaries_fn(self.summary, self.mode)  # pylint: disable=not-callable

    # Write and clear summary data.
    logging.info("Writing summaries for mode %s.", self.mode)
    self.summary.write(self.summary_writer, step, prefix=self.summary_prefix)

    # Add extra summaries that are not computed by the step function.
    if self.extra_summaries_fn is not None:
      self.extra_summary.add(self.extra_summaries_fn(self.mode, step))
      self.extra_summary.write(self.summary_writer, step, prefix="")
| def split_batch_dimension(inputs: Any, num_replicas: int) -> Any: | |
| """Splits the leading batch dimension. | |
| Given inputs of shape [num_replicas * batch_size, ...], it will reshape | |
| them to [num_replicas, batch_size, ...]. This operation is intended to be | |
| used right before calling pmap, which will eliminate the num_replicas | |
| dimension. | |
| Args: | |
| inputs: Tuple of inputs to split. | |
| num_replicas: Number of replicas. | |
| Returns: | |
| inputs with extra batch dimension. | |
| """ | |
| def split_batch_dim(x): | |
| assert x.ndim > 0 | |
| if (x.shape[0] % num_replicas) != 0: | |
| raise ValueError(f"Can't split {x.shape} into {num_replicas} replicas.") | |
| batch_size = x.shape[0] // num_replicas | |
| split_shape = [num_replicas, batch_size] + list(x.shape[1:]) | |
| return np.reshape(x, split_shape) | |
| return jax.tree_map(split_batch_dim, inputs) | |