| """Utilities for GER training.""" |
|
|
| import copy |
| from typing import Any |
|
|
| from absl import logging |
| import flax |
| from flax import jax_utils |
| from flax.training import checkpoints |
| import jax |
| import jax.numpy as jnp |
| import ml_collections |
| import numpy as np |
| import optax |
| from scenic.dataset_lib import dataset_utils |
| from scenic.train_lib import lr_schedules |
| from scenic.train_lib import optimizers as optimizer_lib |
|
|
| from tensorflow.io import gfile |
|
|
| PyTree = Any |
|
|
|
|
class EntityIds2Code():
  """Quantization with given token ids at initialization."""

  def __init__(self, config: ml_collections.ConfigDict):
    """Builds the entity-id to code lookup table."""
    self.config = config
    self.bos = config.get('ger_bos', 101)
    if self.config.get('load_codes_from'):
      logging.info('Loading all codes from: %s', config.load_codes_from)
      with gfile.GFile(config.load_codes_from, 'rb') as f:
        codes = np.load(f)
    else:
      # No code file provided: fall back to random atomic ids.
      logging.info('Codes not found; using random atomic ids instead.')
      np.random.seed(config.get('seed', 0))
      ne = config.get('n_entities', 6084491)
      nq = config.code_length
      codes = np.random.choice(config.vocab_size, ne * nq).reshape((ne, nq))
    self.codes = jnp.array(codes.astype(np.int32))

  def __call__(
      self, inputs: jax.Array, train: bool = False,
      debug: bool = False) -> jax.Array:
    del debug, train
    tokens = self.encode_to_indices(inputs)
    # Offset the raw codes by 2 so they do not collide with the lowest
    # (special) token ids.
    tokens = tokens + 2
    # Prepend the BOS token to every code sequence.
    b = tokens.shape[0]
    tokens = jnp.concatenate(
        [self.bos * jnp.ones((b, 1)), tokens], axis=-1).astype('int32')
    return jax.lax.stop_gradient(tokens)

  def encode_to_indices(self, inputs: jax.Array) -> jax.Array:
    return self.codes[inputs]


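# A minimal usage sketch (hypothetical config values, not part of the training
# pipeline): with a small random codebook, a batch of entity ids comes back as
# BOS-prefixed sequences of length code_length + 1.
def _example_entity_ids_to_code():
  config = ml_collections.ConfigDict({
      'code_length': 4,
      'vocab_size': 100,
      'n_entities': 10,
      'ger_bos': 101,
  })
  entity_tokenizer = EntityIds2Code(config)
  entity_ids = jnp.array([0, 3, 7])
  tokens = entity_tokenizer(entity_ids)
  return tokens  # Shape (3, 5): BOS followed by the 4 shifted code tokens.

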
def get_code2id(entity_codes):
  """Gets a code-string to entity-id mapping."""
  code2id = {}
  # Mirror the +2 shift applied in EntityIds2Code.__call__ without mutating
  # the caller's array.
  entity_codes = entity_codes + 2
  for i, code in enumerate(entity_codes):
    code_str = '-'.join([str(int(c)) for c in code])
    code2id[code_str] = i
  return code2id


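# An illustrative round trip (hypothetical codes): the string keys produced
# here match the shifted token sequences emitted by EntityIds2Code, minus the
# BOS prefix.
def _example_code2id_roundtrip():
  codes = np.array([[1, 4, 2], [3, 0, 5]])
  code2id = get_code2id(codes)
  # {'3-6-4': 0, '5-2-7': 1}
  return code2id

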
def load_weights(train_state, config):
  """Loads pretrained weights or a checkpoint.

  Args:
    train_state: the train state whose parameters need to be restored.
    config: config dict that should contain "weights", the path of the
      checkpoint.

  Returns:
    train_state: restored train_state.
    start_step: step number to resume from.
  """
  start_step = 0
  weight_path = config.get('weights', '')
  skip_wrong_shape = config.get('skip_wrong_shape', False)
  load_prefix = config.get('load_prefix', '')
  ignored_keys = config.get('ignored_keys', '')
  if weight_path:
    logging.info('Loading weights from %s', weight_path)
    weight_data = checkpoints.restore_checkpoint(weight_path, None)
    if 'params' in weight_data:
      restored_params = weight_data['params']
    else:
      # Older checkpoint layout: parameters stored under optimizer/target.
      restored_params = weight_data['optimizer']['target']
      if 'params' in restored_params:
        restored_params = restored_params['params']

    expected_params = train_state.params.unfreeze()
    flattened_restored_params = flax.traverse_util.flatten_dict(
        restored_params, sep='/')
    if load_prefix:
      flattened_restored_params = {
          load_prefix + k: v for k, v in flattened_restored_params.items()}
    flattened_expected_params = flax.traverse_util.flatten_dict(
        expected_params, sep='/')
    extra_keys = flattened_restored_params.keys(
    ) - flattened_expected_params.keys()
    missing_keys = flattened_expected_params.keys(
    ) - flattened_restored_params.keys()
    logging.info('Extra keys in restored params: %s', extra_keys)
    logging.info('Missing keys in restored params: %s', missing_keys)
    for k, v in flattened_restored_params.items():
      if ignored_keys and k.startswith(ignored_keys):
        logging.info('Skipping parameter %s because it starts with %s.', k,
                     ignored_keys)
        continue
      if k not in flattened_expected_params:
        logging.info(
            'Skipping parameter %s: in restored model, but not in target.', k)
        continue
      if flattened_expected_params[k].shape != v.shape:
        logging.info(
            'Key: %s. Expected shape: %s. Restored shape: %s', k,
            flattened_expected_params[k].shape, v.shape)
        if not skip_wrong_shape:
          raise ValueError(
              'Shape mismatch between restored and target model. '
              'Set config.skip_wrong_shape = True if this is expected.')
      else:
        flattened_expected_params[k] = v
    new_params = flax.traverse_util.unflatten_dict(
        flattened_expected_params, sep='/')
    train_state = train_state.replace(params=flax.core.FrozenDict(new_params))
  return train_state, start_step


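# A self-contained sketch of the partial-restore logic above, on plain dicts
# instead of a real checkpoint (names and shapes are made up): entries whose
# key and shape match are copied over, everything else keeps its initialized
# value.
def _example_partial_restore():
  expected = {'encoder': {'w': np.zeros((2, 2))}, 'head': {'w': np.zeros(3)}}
  restored = {'encoder': {'w': np.ones((2, 2))}, 'extra': {'b': np.ones(1)}}
  flat_expected = flax.traverse_util.flatten_dict(expected, sep='/')
  flat_restored = flax.traverse_util.flatten_dict(restored, sep='/')
  for k, v in flat_restored.items():
    if k in flat_expected and flat_expected[k].shape == v.shape:
      flat_expected[k] = v
  # 'encoder/w' is restored; 'head/w' keeps zeros; 'extra/b' is ignored.
  return flax.traverse_util.unflatten_dict(flat_expected, sep='/')

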
def optimizer_with_decoder_multiplier(
    config: ml_collections.ConfigDict,
    params: PyTree,
    use_frozen_params: bool = True):
  """Returns an optimizer with a decoder learning rate multiplier.

  Args:
    config: The training config.
    params: The parameters of the model being trained.
    use_frozen_params: If True, the optimizer will always expect to receive
      a FrozenDict of parameters and gradients.

  Returns:
    An Optax optimizer.
  """
  optimizer_config = config.optimizer
  # Work on an unlocked copy so the multiplier fields can be removed below.
  optimizer_config = copy.deepcopy(optimizer_config).unlock()
  base_learning_rate = config.lr_configs.base_learning_rate

  decoder_layer_prefix = optimizer_config.decoder_layer_prefix
  decoder_multiplier = optimizer_config.decoder_multiplier
  decoder_learning_rate = base_learning_rate * decoder_multiplier
  del optimizer_config.decoder_layer_prefix
  del optimizer_config.decoder_multiplier
  logging.info('Decoder learning rate: %s', decoder_learning_rate)

  decoder_config = copy.deepcopy(config)
  decoder_config.lr_configs.base_learning_rate = decoder_learning_rate

  learning_rate_fns = lr_schedules.get_learning_rate_fn(config)
  decoder_learning_rate_fns = lr_schedules.get_learning_rate_fn(
      decoder_config)

  optimizers = {
      False: optimizer_lib.get_optimizer(
          optimizer_config, learning_rate_fns, params),
      True: optimizer_lib.get_optimizer(
          optimizer_config, decoder_learning_rate_fns, params),
  }

  def is_decoder(name: str) -> bool:
    return name.startswith(decoder_layer_prefix)

  flat_params = flax.traverse_util.flatten_dict(
      flax.core.unfreeze(params), keep_empty_nodes=True, sep='/')
  flat_layer_map = {k: is_decoder(k) for k in flat_params}
  layer_map = flax.traverse_util.unflatten_dict(flat_layer_map, sep='/')
  if use_frozen_params:
    layer_map = flax.core.freeze(layer_map)

  logging.info(
      'Layer assignments:\n%s',
      flax.traverse_util.flatten_dict(layer_map, sep='/'))
  tx = optax.multi_transform(optimizers, layer_map)
  return tx


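# An illustrative sketch (hypothetical parameter names) of the labeling scheme
# used above: parameters whose flattened name starts with 'decoder' are routed
# to one transform, everything else to another.
def _example_decoder_label_map():
  params = {'decoder': {'w': jnp.zeros(2)}, 'encoder': {'w': jnp.zeros(2)}}
  flat = flax.traverse_util.flatten_dict(params, sep='/')
  labels = flax.traverse_util.unflatten_dict(
      {k: k.startswith('decoder') for k in flat}, sep='/')
  tx = optax.multi_transform(
      {True: optax.sgd(1e-3), False: optax.sgd(1e-4)}, labels)
  return tx.init(params)

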
def to_cpu(array: jnp.ndarray):
  """Transfers array (replicated on multiple hosts) to a single host.

  Args:
    array: Replicated array of shape
      [num_hosts, num_devices, local_batch_size, ...].

  Returns:
    array of shape [global_batch_size, ...] where
      global_batch_size = num_devices * local_batch_size.
  """
  return jax.device_get(dataset_utils.unshard(jax_utils.unreplicate(array)))


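# A small shape sketch (illustrative sizes only): the leading replicated axis
# is dropped by unreplicate and the next two axes are merged by unshard.
def _example_to_cpu_shapes():
  x = jnp.zeros((2, 4, 8, 16))  # [num_hosts, num_devices, batch, features].
  return to_cpu(x).shape  # (32, 16): host axis dropped, 4 * 8 examples merged.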