# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for GER training."""

import copy
from typing import Any

from absl import logging
import flax
from flax import jax_utils
from flax.training import checkpoints
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
from scenic.dataset_lib import dataset_utils
from scenic.train_lib import lr_schedules
from scenic.train_lib import optimizers as optimizer_lib

from tensorflow.io import gfile

PyTree = Any  # JAX team is working on type annotation for pytree:


class EntityIds2Code():
  """Quantization with given token ids at initialization."""

  def __init__(self, config: ml_collections.ConfigDict):
    """Entity id to code."""
    self.config = config
    self.bos = config.get('ger_bos', 101)
    if self.config.get('load_codes_from'):
      logging.info('Loading all codes from: %s', config.load_codes_from)
      with gfile.GFile(config.load_codes_from, 'rb') as f:
        codes = np.load(f)
    else:
      # If we don't find a code file to start from, we simply use random codes.
      logging.info('Codes not found --> using randomly generated atomic ids.')
      np.random.seed(config.get('seed', 0))
      ne = config.get('n_entities', 6084491)
      nq = config.code_length
      codes = np.random.choice(config.vocab_size, ne * nq).reshape((ne, nq))
    self.codes = jnp.array(codes.astype(np.int32))

  def __call__(
      self, inputs: jax.Array, train: bool = False,
      debug: bool = False) -> jax.Array:
    del debug, train
    tokens = self.encode_to_indices(inputs)
    # Add two to each code id to make room for the sos and eos tokens.
    tokens = tokens + 2
    # Shift right by prepending the BOS token.
    b = tokens.shape[0]
    tokens = jnp.concatenate(
        [self.bos * jnp.ones((b, 1)), tokens], axis=-1).astype('int32')
    return jax.lax.stop_gradient(tokens)

  def encode_to_indices(self, inputs: jax.Array) -> jax.Array:
    return self.codes[inputs]

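# Illustrative usage sketch (not part of the original module); the config keys
# below mirror the ones read in EntityIds2Code.__init__, and with no
# 'load_codes_from' entry random codes are generated from the given seed:
#
#   config = ml_collections.ConfigDict(
#       {'vocab_size': 1000, 'code_length': 4, 'n_entities': 100, 'seed': 0})
#   quantizer = EntityIds2Code(config)
#   tokens = quantizer(jnp.array([3, 7, 42]))  # Shape (3, 1 + code_length).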

def get_code2id(entity_codes):
  """Gets a code to entity id mapping."""
  code2id = {}
  # Apply the same +2 sos/eos shift used in EntityIds2Code.
  entity_codes += 2
  for i, code in enumerate(entity_codes):
    code_str = '-'.join([str(int(c)) for c in code])
    code2id[code_str] = i
  return code2id

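# Illustrative sketch (not part of the original module): keys are the
# '-'-joined codes after the +2 shift, values are the entity row indices:
#
#   code2id = get_code2id(np.array([[1, 2], [3, 4]]))
#   # {'3-4': 0, '5-6': 1}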

def load_weights(train_state, config):
  """Load pretrained weights or checkpoint.

  Args:
    train_state: the parameters that need to be restored.
    config: config dict that should contain "weights": the path of the
      checkpoint.

  Returns:
    train_state: restored train_state.
    start_step: step to resume from (always 0 here; only weights are restored).
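
  Example (illustrative sketch; the checkpoint path is a placeholder):

    config = ml_collections.ConfigDict(
        {'weights': '/path/to/checkpoint', 'skip_wrong_shape': True})
    train_state, start_step = load_weights(train_state, config)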
  """
  start_step = 0
  weight_path = config.get('weights', '')
  skip_wrong_shape = config.get('skip_wrong_shape', False)
  load_prefix = config.get('load_prefix', '')
  ignored_keys = config.get('ignored_keys', '')
  if weight_path:
    logging.info('Loading weights from %s', weight_path)
    weight_data = checkpoints.restore_checkpoint(weight_path, None)
    if 'params' in weight_data:
      restored_params = weight_data['params']
    else:
      # Old Scenic train state format.
      restored_params = weight_data['optimizer']['target']
      if 'params' in restored_params:  # Backward compatibility.
        restored_params = restored_params['params']

    expected_params = train_state.params.unfreeze()
    flattened_restored_params = flax.traverse_util.flatten_dict(
        restored_params, sep='/')
    if load_prefix:
      flattened_restored_params = {
          load_prefix + k: v for k, v in flattened_restored_params.items()}
    flattened_expected_params = flax.traverse_util.flatten_dict(
        expected_params, sep='/')
    extra_keys = (flattened_restored_params.keys()
                  - flattened_expected_params.keys())
    missing_keys = (flattened_expected_params.keys()
                    - flattened_restored_params.keys())
    logging.info('Inspect extra keys: %s', extra_keys)
    logging.info('Inspect missing keys: %s', missing_keys)
    for k, v in flattened_restored_params.items():
      if ignored_keys and k.startswith(ignored_keys):
        logging.info('Skipping parameter %s because it starts with %s.', k,
                     ignored_keys)
        continue
      if k not in flattened_expected_params:
        logging.info(
            'Skipping parameter %s: present in the restored model but not in '
            'the target model.', k)
        continue
      if flattened_expected_params[k].shape != v.shape:
        logging.info(
            'Key: %s. Expected shape: %s. Restored shape: %s', k,
            flattened_expected_params[k].shape, v.shape)
        if not skip_wrong_shape:
          raise ValueError(
              'Shape mismatch between restored and target model. '
              'Set config.skip_wrong_shape = True if this is expected.')
      else:
        flattened_expected_params[k] = v
    new_params = flax.traverse_util.unflatten_dict(
        flattened_expected_params, sep='/')
    train_state = train_state.replace(params=flax.core.FrozenDict(new_params))
  return train_state, start_step


def optimizer_with_decoder_multiplier(
    config: ml_collections.ConfigDict,
    params: PyTree,
    use_frozen_params: bool = True):
  """Returns an optimizer with decoder learning rate multiplier.


  Args:
    config: The training config.
    params: The parameters of the model being trained.
    use_frozen_params: If True, the optimizer will always expect to receive
      a FrozenDict of parameters and gradients.

  Returns:
    An Optax optimizer.
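
  Example (illustrative sketch; the config field values are placeholders):

    config.optimizer.decoder_layer_prefix = 'decoder/'
    config.optimizer.decoder_multiplier = 0.1
    tx = optimizer_with_decoder_multiplier(config, train_state.params)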
  """
  optimizer_config = config.optimizer
  # Avoid modifying original config and allow alteration.
  optimizer_config = copy.deepcopy(optimizer_config).unlock()
  base_learning_rate = config.lr_configs.base_learning_rate

  decoder_layer_prefix = optimizer_config.decoder_layer_prefix
  decoder_multiplier = optimizer_config.decoder_multiplier
  decoder_learning_rate = base_learning_rate * decoder_multiplier
  del optimizer_config.decoder_layer_prefix
  del optimizer_config.decoder_multiplier
  logging.info('Decoder learning rate: %s', decoder_learning_rate)

  decoder_config = copy.deepcopy(config)
  decoder_config.lr_configs.base_learning_rate = decoder_learning_rate

  learning_rate_fns = lr_schedules.get_learning_rate_fn(config)
  decoder_learning_rate_fns = lr_schedules.get_learning_rate_fn(
      decoder_config)

  optimizers = {
      False: optimizer_lib.get_optimizer(  # not decoder
          optimizer_config, learning_rate_fns, params),
      True: optimizer_lib.get_optimizer(  # is decoder
          optimizer_config, decoder_learning_rate_fns, params),
  }

  def is_decoder(name: str) -> bool:
    return name.startswith(decoder_layer_prefix)

  flat_params = flax.traverse_util.flatten_dict(
      flax.core.unfreeze(params), keep_empty_nodes=True, sep='/')
  flat_layer_map = {k: is_decoder(k) for k in flat_params}
  layer_map = flax.traverse_util.unflatten_dict(flat_layer_map, sep='/')
  if use_frozen_params:
    layer_map = flax.core.freeze(layer_map)

  logging.info(
      'Layer assignments:\n%s',
      flax.traverse_util.flatten_dict(layer_map, sep='/'))
  tx = optax.multi_transform(optimizers, layer_map)
  return tx


def to_cpu(array: jnp.ndarray):
  """Transfers array (replicated on multiple hosts) to a single host.

  Args:
    array: Replicated array of shape
      [num_hosts, num_devices, local_batch_size, ...].

  Returns:
    array of shape [global_batch_size, ...] where
      global_batch_size = num_devices * local_batch_size
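
  Example (illustrative; the shape values are placeholders):

    x = jnp.zeros((2, 8, 4, 16))  # [num_hosts, num_devices, local_batch, dim]
    y = to_cpu(x)  # NumPy array of shape (8 * 4, 16).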
  """
  return jax.device_get(dataset_utils.unshard(jax_utils.unreplicate(array)))