Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Training utils.""" | |
| import dataclasses | |
| import inspect | |
| import json | |
| import os | |
| import pprint | |
| from typing import Any, Callable, Dict, List, Optional, Union | |
| from absl import logging | |
| import gin | |
| import numpy as np | |
| import orbit | |
| import tensorflow as tf, tf_keras | |
| # pylint: disable=g-direct-tensorflow-import | |
| from tensorflow.python.framework import ops | |
| from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph | |
| # pylint: enable=g-direct-tensorflow-import | |
| from official.core import base_task | |
| from official.core import base_trainer | |
| from official.core import config_definitions | |
| from official.core import exp_factory | |
| from official.modeling import hyperparams | |
| BEST_CHECKPOINT_NAME = 'best_ckpt' | |
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
  """Fetches the leaf value reached by following `keys` through nested dict `d`.

  Args:
    d: The (possibly nested) dictionary to read from.
    keys: The sequence of keys to follow, one per nesting level.

  Returns:
    The leaf value found at the end of the key path.

  Raises:
    KeyError: If the path cannot be followed, or if the value at the end of
      the path is itself a dictionary (i.e. not a leaf).
  """
  node = d
  for key in keys:
    if isinstance(node, dict) and key in node:
      node = node[key]
    else:
      raise KeyError(
          'Path not exist while traversing the dictionary: d with keys'
          ': %s.' % keys)
  if isinstance(node, dict):
    raise KeyError('The value extracted with keys: %s is not a leaf of the '
                   'dictionary: %s.' % (keys, d))
  return node
def cast_leaf_nested_dict(d: Dict[str, Any],
                          cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
  """Applies `cast_fn` to every leaf of a nested dictionary, in place.

  Args:
    d: The (possibly nested) dictionary whose leaves are converted.
    cast_fn: Function applied to each non-dict leaf value.

  Returns:
    The same dictionary `d`, with every leaf replaced by `cast_fn(leaf)`.
  """
  for key in d:
    value = d[key]
    if isinstance(value, dict):
      d[key] = cast_leaf_nested_dict(value, cast_fn)
    else:
      d[key] = cast_fn(value)
  return d
| def _filter_leaf_nested_dict( | |
| d: Dict[str, Any], predicate: Callable[[Any], bool] | |
| ) -> Dict[str, Any]: | |
| """Filters the leaves of a dictionary with arbitrary depth in place. | |
| Args: | |
| d: The dictionary to extract value from. | |
| predicate: A function that will be called on every leave item. When the | |
| function returns True the leave will be kept. Otherwise the leave will be | |
| dropped. | |
| Returns: | |
| A new dictionray with filtered result. | |
| """ | |
| result = {} | |
| for key, value in d.items(): | |
| if isinstance(value, dict): | |
| result[key] = _filter_leaf_nested_dict(value, predicate) | |
| elif predicate(value): | |
| result[key] = value | |
| return result | |
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
                                    data_dir: str) -> Any:
  """Maybe create a BestCheckpointExporter object, according to the config."""
  export_subdir = params.trainer.best_checkpoint_export_subdir
  metric_name = params.trainer.best_checkpoint_eval_metric
  metric_comp = params.trainer.best_checkpoint_metric_comp
  # Export is enabled only when the data dir, the export subdir and the
  # metric to track are all configured; otherwise return None (disabled).
  if not (data_dir and export_subdir and metric_name):
    return None
  exporter = BestCheckpointExporter(
      os.path.join(data_dir, export_subdir), metric_name, metric_comp)
  logging.info(
      'Created the best checkpoint exporter. '
      'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
      export_subdir, metric_name)
  return exporter
class BestCheckpointExporter:
  """Keeps track of the best result, and saves its checkpoint.

  Orbit will support an API for checkpoint exporter. This class will be used
  together with orbit once this functionality is ready.
  """

  def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
    """Initialization.

    Args:
      export_dir: The directory that will contain exported checkpoints.
      metric_name: Indicates which metric to look at, when determining which
        result is better. If eval_logs being passed to maybe_export_checkpoint
        is a nested dictionary, use `|` as a separator for different layers.
      metric_comp: Indicates how to compare results. Either `lower` or `higher`.
    """
    self._export_dir = export_dir
    # A nested metric name like 'a|b' becomes the key path ['a', 'b'].
    self._metric_name = metric_name.split('|')
    self._metric_comp = metric_comp
    if self._metric_comp not in ('lower', 'higher'):
      raise ValueError('best checkpoint metric comp must be one of '
                       'higher, lower. Got: {}'.format(self._metric_comp))
    tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
    # Resume tracking from a previous run's exported metrics, if present.
    self._best_ckpt_logs = self._maybe_load_best_eval_metric()
    self._checkpoint_manager = None

  def _get_checkpoint_manager(self, checkpoint):
    """Gets an existing checkpoint manager or creates a new one."""
    if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
                                            != checkpoint):
      logging.info('Creates a new checkpoint manager.')
      # max_to_keep=1: only the single best checkpoint is retained.
      self._checkpoint_manager = tf.train.CheckpointManager(
          checkpoint,
          directory=self._export_dir,
          max_to_keep=1,
          checkpoint_name=BEST_CHECKPOINT_NAME)
    return self._checkpoint_manager

  def maybe_export_checkpoint(
      self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
    """Compare eval_logs with past eval_logs and export checkpoint if better.

    Args:
      checkpoint: A tf.train.Checkpoint to save when the new result is better.
      eval_logs: A (possibly nested) dictionary of evaluation results.
      global_step: The training step the eval_logs were produced at.
      write_logs: Whether to also write the metrics JSON file on export.

    Returns:
      True if the checkpoint was exported (new result is better), else False.
    """
    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
                 eval_logs, global_step)
    if self._best_ckpt_logs is None or self._new_metric_is_better(
        self._best_ckpt_logs, eval_logs):
      self._best_ckpt_logs = eval_logs
      if write_logs:
        self.export_best_eval_metric(self._best_ckpt_logs, global_step)
      self._get_checkpoint_manager(checkpoint).save()
      return True
    return False

  def _maybe_load_best_eval_metric(self):
    # Returns the previously exported metrics dict, or None on first run.
    if not tf.io.gfile.exists(self.best_ckpt_logs_path):
      return None
    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
      return json.loads(reader.read())

  def _new_metric_is_better(self, old_logs, new_logs):
    """Check if the metric in new_logs is better than the metric in old_logs."""
    old_value = float(
        orbit.utils.get_value(
            get_leaf_nested_dict(old_logs, self._metric_name)))
    new_value = float(
        orbit.utils.get_value(
            get_leaf_nested_dict(new_logs, self._metric_name)))
    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
                 old_value, new_value)
    if self._metric_comp == 'higher':
      if new_value > old_value:
        logging.info('[BestCheckpointExporter] '
                     'the new number is better since it is higher.')
        return True
    else:  # self._metric_comp == 'lower':
      if new_value < old_value:
        logging.info('[BestCheckpointExporter] '
                     'the new number is better since it is lower.')
        return True
    return False

  def export_best_eval_metric(self, eval_logs, global_step):
    """Export evaluation results of the best checkpoint into a json file."""
    # eval_log_ext may contains non-scalar tensors, such as image data when
    # `allow_image_summary` is True. Here we only keep scalar tensors.
    eval_logs_ext = _filter_leaf_nested_dict(
        eval_logs, lambda x: tf.rank(x) <= 1
    )
    eval_logs_ext['best_ckpt_global_step'] = global_step
    eval_logs_ext = cast_leaf_nested_dict(
        eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
    # Saving json file is very fast.
    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
      writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')

  # NOTE(fix): the three accessors below were plain methods, but the class
  # itself uses `self.best_ckpt_logs_path` as a path *string* (in __init__,
  # _maybe_load_best_eval_metric and export_best_eval_metric), which would
  # pass a bound method to os.path/gfile. They must be properties.
  @property
  def best_ckpt_logs(self):
    # Eval logs of the best checkpoint seen so far; None before any export.
    return self._best_ckpt_logs

  @property
  def best_ckpt_logs_path(self):
    # JSON file holding the exported metrics of the best checkpoint.
    return os.path.join(self._export_dir, 'info.json')

  @property
  def best_ckpt_path(self):
    """Returns the best ckpt path or None if there is no ckpt yet."""
    return tf.train.latest_checkpoint(self._export_dir)
def create_optimizer(task: base_task.Task,
                     params: config_definitions.ExperimentConfig
                     ) -> tf_keras.optimizers.Optimizer:
  """Creates the optimizer, staying backward compatible with new args.

  Newer tasks accept a `dp_config` keyword on `create_optimizer`; older ones
  do not, so the task's signature is inspected before calling it.

  Args:
    task: The task whose `create_optimizer` is invoked.
    params: The experiment config supplying optimizer/runtime/dp settings.

  Returns:
    The optimizer built by the task.

  Raises:
    ValueError: If a differential privacy config is set but the task's
      `create_optimizer` does not accept `dp_config`.
  """
  accepts_dp_config = (
      'dp_config' in inspect.signature(task.create_optimizer).parameters)
  dp_config = getattr(params.task, 'differential_privacy_config', None)
  if accepts_dp_config:
    return task.create_optimizer(
        params.trainer.optimizer_config, params.runtime, dp_config=dp_config)
  if dp_config is not None:
    raise ValueError('Differential privacy config is specified but '
                     'task.create_optimizer api does not accept it.')
  return task.create_optimizer(
      params.trainer.optimizer_config, params.runtime)
def create_trainer(params: config_definitions.ExperimentConfig,
                   task: base_task.Task,
                   train: bool,
                   evaluate: bool,
                   checkpoint_exporter: Optional[BestCheckpointExporter] = None,
                   trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
  """Creates a trainer for the given task.

  Args:
    params: The experiment config.
    task: The task the trainer will run.
    train: Whether the trainer performs training.
    evaluate: Whether the trainer performs evaluation.
    checkpoint_exporter: Optional best-checkpoint exporter.
    trainer_cls: The trainer class to instantiate.

  Returns:
    An instance of `trainer_cls` wired with the task's model and optimizer.
  """
  logging.info('Running default trainer.')
  # Keyword arguments are evaluated left to right: the model is built before
  # the optimizer, matching the original construction order.
  return trainer_cls(
      params,
      task,
      model=task.build_model(),
      optimizer=create_optimizer(task, params),
      train=train,
      evaluate=evaluate,
      checkpoint_exporter=checkpoint_exporter)
@dataclasses.dataclass
class ParseConfigOptions:
  """Use this dataclass instead of FLAGS to customize parse_configuration()."""
  # NOTE(fix): the @dataclasses.dataclass decorator was missing. Without it
  # the class cannot be instantiated (no generated __init__, `experiment` has
  # no default) and `dataclasses.asdict(self)` raises TypeError.
  experiment: str
  config_file: List[str]
  tpu: str = ''
  tf_data_service: str = ''
  params_override: str = ''

  def __contains__(self, name):
    # Membership mirrors the declared dataclass fields so this object can
    # stand in for FLAGS in `'tf_data_service' in flags_obj` style checks.
    return name in dataclasses.asdict(self)
class ExperimentParser:
  """Constructs the Experiment config from Flags or equivalent object.

  Most of the cases, users only need to call the `parse()` function:
  ```
  builder = ExperimentParser(FLAGS)
  params = builder.parse()
  ```

  The advanced users can modify the flow by calling the parse_*() functions
  separately.
  """

  def __init__(self, flags_obj):
    self._flags_obj = flags_obj

  def parse(self):
    """Runs the full pipeline of constructing the Experiment config."""
    params = self.base_experiment()
    # Each stage takes the current params and returns the updated params.
    for stage in (self.parse_config_file, self.parse_runtime,
                  self.parse_data_service, self.parse_params_override):
      params = stage(params)
    return params

  def base_experiment(self):
    """Get the base experiment config from --experiment field."""
    experiment = self._flags_obj.experiment
    if experiment is None:
      raise ValueError('The flag --experiment must be specified.')
    return exp_factory.get_exp_config(experiment)

  def parse_config_file(self, params):
    """Override the configs of params from the config_file."""
    config_files = self._flags_obj.config_file or []
    for config_file in config_files:
      params = hyperparams.override_params_dict(
          params, config_file, is_strict=True)
    return params

  def parse_runtime(self, params):
    """Override the runtime configs of params from flags."""
    # Override the TPU address and tf.data service address.
    params.override({
        'runtime': {
            'tpu': self._flags_obj.tpu,
        },
    })
    return params

  def parse_data_service(self, params):
    """Override the data service configs of params from flags."""
    flags_obj = self._flags_obj
    if ('tf_data_service' in flags_obj and flags_obj.tf_data_service and
        isinstance(params.task, config_definitions.TaskConfig)):
      address = flags_obj.tf_data_service
      params.override({
          'task': {
              'train_data': {
                  'tf_data_service_address': address,
              },
              'validation_data': {
                  'tf_data_service_address': address,
              }
          }
      })
    return params

  def parse_params_override(self, params):
    """Applies the second level of override from `--params_override`."""
    # `--params_override` is typically used as a further override over the
    # template. For example, one may define a particular template for training
    # ResNet50 on ImageNet in a config file and pass it via `--config_file`,
    # then define different learning rates and pass it via `--params_override`.
    if self._flags_obj.params_override:
      params = hyperparams.override_params_dict(
          params, self._flags_obj.params_override, is_strict=True)
    return params
def parse_configuration(flags_obj, lock_return=True, print_return=True):
  """Parses ExperimentConfig from flags.

  Args:
    flags_obj: A FLAGS-like object carrying the experiment flags.
    lock_return: Whether to lock the returned params against further edits.
    print_return: Whether to log the final parameters.

  Returns:
    The validated (and optionally locked) experiment config.
  """
  params = ExperimentParser(flags_obj).parse()
  params.validate()
  if lock_return:
    params.lock()
  if print_return:
    logging.info('Final experiment parameters:\n%s',
                 pprint.PrettyPrinter().pformat(params.as_dict()))
  return params
def serialize_config(params: config_definitions.ExperimentConfig,
                     model_dir: str):
  """Serializes and saves the experiment config.

  Args:
    params: The experiment config to save.
    model_dir: Directory that receives `params.yaml`.

  Raises:
    ValueError: If `model_dir` is None.
  """
  if model_dir is None:
    raise ValueError('model_dir must be specified, but got None')
  save_path = os.path.join(model_dir, 'params.yaml')
  logging.info('Saving experiment configuration to %s', save_path)
  tf.io.gfile.makedirs(model_dir)
  hyperparams.save_params_dict_to_yaml(params, save_path)
def save_gin_config(filename_suffix: str, model_dir: str):
  """Serializes and saves the operative gin configuration.

  NOTE(fix): the previous docstring ("Serializes and saves the experiment
  config") was copy-pasted from serialize_config; this function writes the
  gin operative config string, not the experiment config.

  Args:
    filename_suffix: Suffix for the file name, producing
      `operative_config.<suffix>.gin`.
    model_dir: Directory the gin config file is written to.
  """
  gin_save_path = os.path.join(
      model_dir, 'operative_config.{}.gin'.format(filename_suffix))
  logging.info('Saving gin configurations to %s', gin_save_path)
  tf.io.gfile.makedirs(model_dir)
  with tf.io.gfile.GFile(gin_save_path, 'w') as f:
    f.write(gin.operative_config_str())
def read_global_step_from_checkpoint(ckpt_file_path):
  """Read global step from checkpoint, or get global step from its filename."""
  step_var = tf.Variable(-1, dtype=tf.int64)
  try:
    tf.train.Checkpoint(global_step=step_var).restore(
        ckpt_file_path).expect_partial()
    restored_value = step_var.numpy()
  except tf.errors.InvalidArgumentError:
    # Sentinel meaning `global_step` was absent from the checkpoint.
    restored_value = -1
  if restored_value == -1:
    raise ValueError('global_step not found in checkpoint {}. '
                     'If you want to run finetune eval jobs, you need to '
                     'make sure that your pretrain model writes '
                     'global_step in its checkpoints.'.format(ckpt_file_path))
  global_step_restored = step_var.numpy()
  logging.info('get global_step %d from checkpoint %s', global_step_restored,
               ckpt_file_path)
  return global_step_restored
def write_json_summary(log_dir, global_step, eval_metrics):
  """Dump evaluation metrics to json file."""
  # Tensors (anything exposing .numpy()) and plain values are both
  # stringified so the dict is JSON-serializable.
  serializable_dict = {
      name: str(value.numpy()) if hasattr(value, 'numpy') else str(value)
      for name, value in eval_metrics.items()
  }
  output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
  logging.info('Evaluation results at pretrain step %d: %s', global_step,
               serializable_dict)
  with tf.io.gfile.GFile(output_json, 'w') as writer:
    writer.write(json.dumps(serializable_dict, indent=4) + '\n')
def write_summary(summary_writer, global_step, eval_metrics):
  """Write evaluation metrics to TF summary."""
  scalar_values = {
      name: float(orbit.utils.get_value(value))
      for name, value in eval_metrics.items()
  }
  with summary_writer.as_default():
    for name, value in scalar_values.items():
      tf.summary.scalar(name, value, step=global_step)
    summary_writer.flush()
def remove_ckpts(model_dir):
  """Remove model checkpoints, so we can restart."""
  ckpt_pattern = os.path.join(model_dir, 'ckpt-*')
  logging.info('removing checkpoint files %s', ckpt_pattern)
  for path in tf.io.gfile.glob(ckpt_pattern):
    tf.io.gfile.rmtree(path)
  # Also drop the checkpoint state file so TF does not look for the
  # now-deleted checkpoints.
  state_file = os.path.join(model_dir, 'checkpoint')
  if tf.io.gfile.exists(state_file):
    tf.io.gfile.remove(state_file)
def write_model_params(model: Union[tf.Module, tf_keras.Model],
                       output_path: str) -> None:
  """Writes the model parameters and shapes to a file.

  Args:
    model: A model instance.
    output_path: Output file path.
  """
  with tf.io.gfile.GFile(output_path, 'w') as f:
    total_params = 0
    for var in model.variables:
      var_shape = tf.shape(var)
      total_params += tf.math.reduce_prod(var_shape).numpy()
      f.write(f'{var.name} {var_shape.numpy().tolist()}\n')
    f.write(f'\nTotal params: {total_params}\n')
def try_count_params(
    model: Union[tf.Module, tf_keras.Model],
    trainable_only: bool = False):
  """Count the number of parameters if model is possible.

  Args:
    model: Try to count the number of params in this model.
    trainable_only: Whether to calculate trainable params only. This flag is
      not used when the model has `count_params` attribute.

  Returns:
    The number of parameters or None.
  """
  if hasattr(model, 'count_params'):
    try:
      return model.count_params()
    except ValueError:
      logging.info('Number of trainable params unknown, because the build() '
                   'methods in keras layers were not called. This is probably '
                   'because the model was not feed any input, e.g., the max '
                   'train step already reached before this run.')
      return None
  # Fall back to summing variable sizes for plain tf.Module models.
  variables = model.trainable_variables if trainable_only else model.variables
  return sum(
      tf.math.reduce_prod(tf.shape(var)).numpy() for var in variables)
def try_count_flops(model: Union[tf.Module, tf_keras.Model],
                    inputs_kwargs: Optional[Dict[str, Any]] = None,
                    output_path: Optional[str] = None):
  """Counts and returns model FLOPs.

  Args:
    model: A model instance.
    inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
      shape specifications to getting corresponding concrete function.
    output_path: A file path to write the profiling results to.

  Returns:
    The model's FLOPs, or None if they could not be counted.
  """
  if not hasattr(model, 'inputs'):
    return None
  try:
    # Get input shape and set batch size to 1.
    if model.inputs:
      # NOTE(fix): loop variable renamed from `input`, which shadowed the
      # builtin of the same name.
      input_specs = [
          tf.TensorSpec([1] + inp.shape[1:], inp.dtype)
          for inp in model.inputs
      ]
      concrete_func = tf.function(model).get_concrete_function(input_specs)
    # If model.inputs is invalid, try to use the input to get concrete
    # function for model.call (subclass model).
    else:
      concrete_func = tf.function(model.call).get_concrete_function(
          **inputs_kwargs)
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)

    # Calculate FLOPs.
    run_meta = tf.compat.v1.RunMetadata()
    opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    if output_path is not None:
      opts['output'] = f'file:outfile={output_path}'
    else:
      opts['output'] = 'none'
    flops = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph, run_meta=run_meta, options=opts)
    return flops.total_float_ops
  except Exception as e:  # pylint: disable=broad-except
    # Tracing/profiling can fail for many reasons (e.g. the model was never
    # built); deliberately best-effort, so log and return None.
    logging.info(
        'Failed to count model FLOPs with error %s, because the build() '
        'methods in keras layers were not called. This is probably because '
        'the model was not feed any input, e.g., the max train step already '
        'reached before this run.', e)
    return None
def _einsum_flops(graph, node):
  """Calculates the compute resources needed for Einsum.

  Parses the einsum equation stored in the node's `equation` attribute,
  sizes each index letter from the two input tensor shapes, and returns the
  FLOP count as 2 * (number of multiply-adds).

  Args:
    graph: The graph containing the einsum node (used to resolve the shapes
      of the node's inputs).
    node: A NodeDef-like object for an Einsum op with exactly two inputs.

  Returns:
    An ops.OpStats carrying the computed 'flops' statistic.

  Raises:
    ValueError: If an equation index letter appears in neither input spec.
  """
  # Only the two-operand einsum form (e.g. 'ab,bc->ac') is supported.
  assert len(node.input) == 2
  x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
      graph, node.input[0])
  y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
      graph, node.input[1])
  # Both shapes must be static; symbolic dimensions cannot be counted.
  x_shape.assert_is_fully_defined()
  y_shape.assert_is_fully_defined()
  x_shape = x_shape.as_list()
  y_shape = y_shape.as_list()
  equation = str(node.attr['equation'])
  # Strip the attr-value wrapper tokens (presumably the str() of the proto
  # looks like `s: "ab,bc->ac"` — the 's:', quotes and whitespace are
  # removed to recover the bare equation; TODO confirm against the proto).
  equation = (
      equation.replace('s:', '')
      .replace('"', '')
      .replace(' ', '')
      .replace('\n', '')
  )
  # Split 'x,y->r' into the index strings of the two inputs and the output.
  x_str = equation.split(',')[0]
  y_r_str = equation.split(',')[1]
  y_str = y_r_str.split('->')[0]
  r_str = y_r_str.split('->')[1]
  # Map each index letter to its dimension size; letters missing from the
  # output are the contracted (summed-over) indices.
  shape_dic = {}
  contracted = set()
  for indice in x_str + y_str:
    if indice in x_str:
      indice_dim = x_shape[x_str.find(indice)]
    elif indice in y_str:
      indice_dim = y_shape[y_str.find(indice)]
    else:
      raise ValueError('indice {} not found in inputs'.format(indice))
    shape_dic[indice] = indice_dim
    if indice not in r_str:
      contracted.add(indice)
  # Multiply-adds = (output volume) * (extent of the contracted indices);
  # each multiply-add counts as 2 FLOPs (one mul + one add).
  madds = np.prod([shape_dic[indice] for indice in r_str]) * (
      np.prod([shape_dic[indice] for indice in contracted]))
  flops = 2 * madds
  return ops.OpStats('flops', flops)