"""Model utils for MBT.""" |
|
|
|
|
|
from typing import Any |
|
|
|
|
|
from absl import logging |
|
|
import flax |
|
|
import jax.numpy as jnp |
|
|
import ml_collections |
|
|
import numpy as np |
|
|
from scenic.common_lib import debug_utils |
|
|
from scenic.projects.vivit import model_utils as vivit_utils |
|
|
import scipy |
|
|
|
|
|
|
|
|
|
|
|
central_frame_initializer = vivit_utils.central_frame_initializer |
|
|
average_frame_initializer = vivit_utils.average_frame_initializer |
|
|
tile_positional_embeddings = vivit_utils.tile_positional_embeddings |


def interpolate_positional_embeddings(restored_posemb_grid, n_tokens):
  """Interpolate positional embeddings from one size to another.

  Args:
    restored_posemb_grid: Positional embeddings from the restored model. Shape
      is [n_restored_tokens, d]. It is assumed that the restored model used
      square image patches.
    n_tokens: Number of tokens in the target model. Can be a scalar if the
      target image is square, otherwise a tuple of 2.

  Returns:
    Positional embeddings resized to match n_tokens. Shape is [1, n_tokens, d].
  """
  restored_gs = int(np.sqrt(len(restored_posemb_grid)))
  if isinstance(n_tokens, tuple):
    gh, gw = n_tokens
  else:
    if n_tokens == len(restored_posemb_grid):
      # The target sequence length already matches; no interpolation needed.
      return np.expand_dims(restored_posemb_grid, axis=0)
    gh = int(np.sqrt(n_tokens))
    gw = n_tokens // gh
    assert gh * gw == n_tokens
  logging.info('Resizing grid-size from (%s, %s) to (%s, %s).',
               restored_gs, restored_gs, gh, gw)
  restored_posemb_grid = restored_posemb_grid.reshape(restored_gs, restored_gs,
                                                      -1)
  zoom = (gh / restored_gs, gw / restored_gs, 1)
  restored_posemb_grid = scipy.ndimage.zoom(restored_posemb_grid, zoom, order=1)
  restored_posemb_grid = restored_posemb_grid.reshape(1, gh * gw, -1)
  return restored_posemb_grid
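

# Illustrative sketch of the resize above (the shapes are assumptions for the
# sake of the example, not values used elsewhere in this module): a 14x14 grid
# of 768-d embeddings restored from a ViT is linearly resized to a 16x16 grid.
#
#   posemb_grid = np.zeros((196, 768))     # 14 * 14 = 196 restored tokens.
#   resized = interpolate_positional_embeddings(posemb_grid, 256)
#   assert resized.shape == (1, 256, 768)  # 16 * 16 = 256 target tokens.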


def initialise_from_train_state(
    config,
    train_state: Any,
    restored_train_state: Any,
    restored_model_cfg: ml_collections.ConfigDict,
    restore_output_proj: bool,
    mbt_transformer_key: str = 'Transformer',
    log_initialised_param_shapes: bool = True,
    one_config: bool = True,
    prefix_path: Any = None) -> Any:
  """Updates the train_state with data from restored_train_state.

  This function is written to be used for 'fine-tuning' experiments. Here, we
  do some surgery to support larger resolutions (longer sequence length) in
  the transformer block, with respect to the learned pos-embeddings.

  Args:
    config: Configurations for the model being updated, or tuple of configs.
    train_state: A raw TrainState for the model.
    restored_train_state: A TrainState that is loaded with parameters/state of
      a pretrained model.
    restored_model_cfg: Configuration of the model from which the
      restored_train_state comes. Usually used for some asserts.
    restore_output_proj: If true, load the final output projection. Set to
      False if finetuning to a new dataset.
    mbt_transformer_key: The key used for storing the subtree in the
      parameters that keeps Transformer weights, that are supposed to be
      initialized from the given pre-trained model.
    log_initialised_param_shapes: If true, print tabular summary of all the
      variables in the model once they have been initialised.
    one_config: If true, we have only a single config. If false, we get a
      tuple of configs in the order [init_config, model_config,
      dataset_config]. This is useful for works that build upon MBT and have
      different models in their config.
    prefix_path: If parameters are in a subtree.

  Returns:
    Updated train_state.
  """
  if one_config:
    init_config = config.init_from
    model_config = config.model
    dataset_config = config.dataset_configs
  else:
    init_config, model_config, dataset_config = config

  params = flax.core.unfreeze(train_state.optimizer.target)
  logging.info('Parameters in the target model are: %s', params)

  if init_config.get('checkpoint_format', 'scenic') == 'big_vision':
    restored_params = restored_train_state.optimizer['target']
  else:
    restored_params = restored_train_state.optimizer.target
  restored_params = flax.core.unfreeze(restored_params)
  if init_config.get('init_from_vit', True):
    if prefix_path:
      video_params = params[prefix_path]
    else:
      video_params = params

    for m_key, m_params in restored_params.items():
      if m_key == 'output_projection':
        if restore_output_proj:
          video_params[m_key] = m_params
        else:
          pass
      elif m_key == 'pre_logits':
        if model_config.representation_size is None:
          # The target model has no pre-logits layer, so drop it from the
          # restored parameters rather than carrying over an unused subtree.
          video_params.pop(m_key, None)
        else:
          assert restored_model_cfg.model.representation_size
          video_params[m_key] = m_params

      elif m_key in ['Transformer']:
        for tm_key, tm_params in m_params.items():
          if tm_key == 'posembed_input':
            # Positional embeddings may need a resolution change; initialise
            # the RGB, spectrogram and bottleneck embeddings separately.
            init_posemb(
                video_params[mbt_transformer_key],
                m_params,
                init_config,
                model_config,
                dataset_config,
                restored_model_cfg,
                'posembed_input',
                prefix_path=prefix_path)
            init_posemb(
                video_params[mbt_transformer_key],
                m_params,
                init_config,
                model_config,
                dataset_config,
                restored_model_cfg,
                'posembed_input_spectrogram',
                prefix_path=prefix_path)
            init_posemb(
                video_params,
                m_params,
                init_config,
                model_config,
                dataset_config,
                restored_model_cfg,
                'bottleneck',
                prefix_path=prefix_path)
          elif 'encoderblock' in tm_key:
            logging.info('Loading encoder parameters.')
            init_encoderblock(video_params[mbt_transformer_key], m_params,
                              tm_key)
          else:
            video_params[mbt_transformer_key][tm_key] = tm_params
      elif m_key == 'embedding':
        init_embedding(video_params, m_params, init_config, model_config,
                       'embedding')
        init_embedding(video_params, m_params, init_config, model_config,
                       'embedding_spectrogram')
      else:
        if m_key in train_state.optimizer.target:
          video_params[m_key] = m_params
        if '%s_spectrogram' % m_key in train_state.optimizer.target:
          video_params['%s_spectrogram' % m_key] = m_params
        else:
          logging.info('Skipping %s. In restored model but not in target',
                       m_key)
  else:
    for m_key, m_params in restored_params.items():
      if m_key == 'output_projection':
        if restore_output_proj:
          params[m_key] = m_params
        else:
          pass
      elif m_key == 'pre_logits':
        if model_config.representation_size is None:
          # The target model has no pre-logits layer, so drop it from the
          # restored parameters rather than carrying over an unused subtree.
          params.pop(m_key, None)
        else:
          assert restored_model_cfg.model.representation_size
          params[m_key] = m_params
      else:
        if m_key in train_state.optimizer.target:
          params[m_key] = m_params
        else:
          logging.info('Skipping %s. In restored model but not in target',
                       m_key)

  if log_initialised_param_shapes:
    logging.info('Parameter summary after initialising from train state')
    debug_utils.log_param_shapes(params)
  return train_state.replace(
      optimizer=train_state.optimizer.replace(target=flax.core.freeze(params)))
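

# Hypothetical usage sketch (the surrounding variable names are illustrative
# assumptions, not defined in this module): typically called once from the
# trainer after restoring a pretrained checkpoint, before fine-tuning starts.
#
#   train_state = initialise_from_train_state(
#       config, train_state, restored_train_state, restored_model_cfg,
#       restore_output_proj=False)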


def init_posemb(to_params, from_params, init_config, model_config,
                dataset_config, restored_model_cfg, name, prefix_path=None):
  """Initialize the positional embeddings."""
  if name not in to_params:
    logging.info('No %s in target model', name)
  elif init_config.restore_positional_embedding:
    if name == 'bottleneck':
      posemb = to_params[name]
    else:
      posemb = to_params[name]['pos_embedding']
    restored_posemb = from_params['posembed_input']['pos_embedding']
    if restored_posemb.shape != posemb.shape:
      # Rescale the grid of positional embeddings to the target sequence
      # length. Restored embeddings have shape [1, n_restored_tokens, d].
      logging.info('Adapting positional embeddings %s from %s to %s',
                   name, restored_posemb.shape, posemb.shape)
      ntok = posemb.shape[1]
      if prefix_path:
        # MBT is part of a larger model; its config lives in a subtree.
        classifier = restored_model_cfg.mbt.model.classifier
      else:
        classifier = restored_model_cfg.model.classifier
      if classifier == 'token':
        # Peel off the CLS token before resizing the spatial grid.
        cls_tok = restored_posemb[:, :1]
        restored_posemb_grid = restored_posemb[0, 1:]
      else:
        cls_tok = restored_posemb[:, :0]
        restored_posemb_grid = restored_posemb[0]
      if model_config.classifier == 'token':
        ntok -= 1

      size_change = init_config.positional_embed_size_change
      if name == 'bottleneck':
        restored_posemb_grid = interpolate_positional_embeddings(
            restored_posemb_grid, ntok)
      elif size_change == 'tile':
        restored_posemb_grid = tile_positional_embeddings(
            restored_posemb_grid, ntok)
      elif size_change in ['resize_tile', 'resize']:
        temp_encoding = model_config.temporal_encoding_config
        if name.find('spectrogram') > -1:
          gh = ((dataset_config.spec_shape[0] *
                 dataset_config.num_spec_frames) //
                model_config.patches.size[0])
          gw = (dataset_config.spec_shape[1] //
                model_config.patches.size[1])
          tokens_per_frame = (gh, gw)
        elif temp_encoding.method == 'temporal_sampling':
          tokens_per_frame = int(ntok / temp_encoding.n_sampled_frames)
        elif temp_encoding.method == '3d_conv':
          # Number of temporal positions after the 3D convolution.
          n_frames = (
              dataset_config.num_frames //
              model_config.patches.size[2])
          tokens_per_frame = ntok // n_frames
        else:
          raise AssertionError(
              f'Unknown temporal encoding {temp_encoding.method}')

        restored_posemb_grid = interpolate_positional_embeddings(
            restored_posemb_grid, tokens_per_frame)
        if size_change == 'resize_tile' and ntok != tokens_per_frame:
          restored_posemb_grid = restored_posemb_grid[0]
          restored_posemb_grid = tile_positional_embeddings(
              restored_posemb_grid, ntok)
      else:
        raise AssertionError(
            'Unknown positional embedding size changing method')

      # Attach the CLS token embedding again if the target model uses one.
      if model_config.classifier == 'token':
        restored_posemb = jnp.array(
            np.concatenate([cls_tok, restored_posemb_grid], axis=1))
      else:
        restored_posemb = restored_posemb_grid

    if name == 'bottleneck':
      to_params[name] = restored_posemb
    else:
      to_params[name]['pos_embedding'] = restored_posemb
  else:
    logging.info('Not restoring positional encodings from pretrained model')
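

# Worked example of the 'resize_tile' path in init_posemb above (all numbers
# are illustrative assumptions): with 32 frames of 224x224 RGB, 16x16x2
# patches and '3d_conv' temporal encoding, the target sequence has
# ntok = (32 // 2) * (224 // 16) ** 2 = 3136 tokens and n_frames = 16, so
# tokens_per_frame = 3136 // 16 = 196. The restored 196-token ViT grid is
# first interpolated to 196 tokens per frame (a no-op here) and then tiled 16
# times along the sequence axis to cover all frames.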


def init_embedding(to_params, from_params, init_config, model_config, name):
  """Initialize input embedding."""
  if name not in to_params:
    logging.info('No %s in target model', name)
  elif init_config.get('restore_input_embedding', True):
    input_kernel = to_params[name]['kernel']
    restored_kernel = from_params['kernel']
    restored_bias = from_params['bias']

    if input_kernel.shape != restored_kernel.shape:
      kernel_init_method = model_config.temporal_encoding_config.kernel_init_method
      if input_kernel.shape == restored_kernel.shape[1:]:
        # The restored kernel has an extra leading (temporal) dimension;
        # average it away to match the 2D target kernel.
        restored_kernel = np.mean(restored_kernel, axis=0)
      elif input_kernel.shape[1:] != restored_kernel.shape:
        # Shapes differ beyond a missing leading dimension; fall back to
        # reshaping the restored kernel into the target shape.
        restored_kernel = np.reshape(restored_kernel, input_kernel.shape)
      elif input_kernel.shape[0] == 1:
        # The target kernel has a singleton temporal dimension; add that axis.
        restored_kernel = np.expand_dims(restored_kernel, axis=0)
      elif kernel_init_method == 'average_frame_initializer':
        # Inflate the 2D kernel along time so the 3D convolution averages
        # over its temporal receptive field.
        logging.info('Initializing input kernel with filter inflation.')
        t = input_kernel.shape[0]
        restored_kernel = np.expand_dims(restored_kernel, axis=0)
        restored_kernel = np.tile(restored_kernel, [t, 1, 1, 1, 1]) / t
      elif kernel_init_method == 'average_arp_frame_initializer':
        # Inflate the 2D kernel along time, weighting each temporal position
        # with approximate-rank-pooling (ARP) coefficients.
        logging.info('Initializing input kernel with ARP inflation.')
        t = input_kernel.shape[0]
        restored_kernel = np.expand_dims(restored_kernel, axis=0)
        restored_kernel = np.tile(restored_kernel, [t, 1, 1, 1, 1])

        def average_arp(length):
          # Approximate-rank-pooling weights for a sequence of `length` steps.
          array = np.arange(1, length + 1)
          harmonic = np.zeros((length + 1))
          harmonic[1:] = np.cumsum(1.0 / array)
          array = 2 * (length - array + 1) - (length + 1) * (
              harmonic[-1] - harmonic[:-1])
          return array

        normalizer = average_arp(t) / t
        normalizer = np.reshape(normalizer, [t, 1, 1, 1, 1])
        restored_kernel = restored_kernel * normalizer
      elif kernel_init_method == 'central_frame_initializer':
        logging.info('Initializing input kernel to select centre frame.')
        central_time_index = input_kernel.shape[0] // 2
        temp = np.zeros(input_kernel.shape)
        temp[central_time_index] = restored_kernel.copy()
        restored_kernel = temp
      else:
        raise AssertionError(
            'Unknown input kernel initialization {}'.format(kernel_init_method))

    to_params[name]['kernel'] = restored_kernel
    to_params[name]['bias'] = restored_bias
  else:
    logging.info('Not restoring input embedding parameters')
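

# Illustrative example of the 'average_frame_initializer' branch above (shapes
# are assumptions for the example only): a restored 2D patch-embedding kernel
# of shape (16, 16, 3, 768) is expanded to (1, 16, 16, 3, 768), tiled to
# (t, 16, 16, 3, 768) for temporal patch size t, and divided by t, so the
# inflated 3D convolution initially computes the average of its t input
# frames.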


def init_encoderblock(to_params, from_params, tm_key):
  """Initialize encoder block parameters."""
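  # The same pretrained encoder block initializes up to two target towers: the
  # one stored under `tm_key` and, when present, the spectrogram tower stored
  # under '%s_spectrogram' % tm_key (see the loop below).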
  for enc_key in from_params[tm_key].keys():
    restoring_params = False
    if tm_key in to_params:
      assert enc_key in to_params[tm_key], '%s not in to_params[%s]' % (
          enc_key, tm_key)
      to_params[tm_key][enc_key] = from_params[tm_key][enc_key]
      restoring_params = True
    spec_key = '%s_spectrogram' % tm_key
    if spec_key in to_params:
      assert enc_key in to_params[spec_key], '%s not in to_params[%s]' % (
          enc_key, spec_key)
      to_params[spec_key][enc_key] = from_params[tm_key][enc_key]
      restoring_params = True
    if not restoring_params:
      logging.info('Warning: Not restoring encoder parameters.')