# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model utils for MBT."""
from typing import Any
from absl import logging
import flax
import jax.numpy as jnp
import ml_collections
import numpy as np
from scenic.common_lib import debug_utils
from scenic.projects.vivit import model_utils as vivit_utils
import scipy.ndimage

# Shared utils with ViViT.
central_frame_initializer = vivit_utils.central_frame_initializer
average_frame_initializer = vivit_utils.average_frame_initializer
tile_positional_embeddings = vivit_utils.tile_positional_embeddings


def interpolate_positional_embeddings(restored_posemb_grid, n_tokens):
  """Interpolate positional embeddings from one size to another.

  Args:
    restored_posemb_grid: Positional embeddings from restored model. Shape is
      [n_restored_tokens, d]. It is assumed that the restored model used square
      image patches.
    n_tokens: Number of tokens in the target model. Can be a scalar if the
      target image is square, otherwise should be a tuple of 2.

  Returns:
    Positional embedding resized to match n_tokens. Shape is [1, n_tokens, d].
  """
  restored_gs = int(np.sqrt(len(restored_posemb_grid)))
  if isinstance(n_tokens, tuple):
    gh, gw = n_tokens
  else:
    if n_tokens == len(restored_posemb_grid):
      # No need to interpolate.
      return np.expand_dims(restored_posemb_grid, axis=0)
    gh = int(np.sqrt(n_tokens))
    gw = n_tokens // gh
    assert gh * gw == n_tokens
  logging.info('Resizing grid-size from (%s, %s) to (%s, %s).',
               restored_gs, restored_gs, gh, gw)
  restored_posemb_grid = restored_posemb_grid.reshape(restored_gs, restored_gs,
                                                      -1)
  zoom = (gh / restored_gs, gw / restored_gs, 1)
  restored_posemb_grid = scipy.ndimage.zoom(restored_posemb_grid, zoom, order=1)
  restored_posemb_grid = restored_posemb_grid.reshape(1, gh * gw, -1)
  return restored_posemb_grid
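
# A minimal usage sketch for `interpolate_positional_embeddings` (shapes are
# illustrative, not from this file): resizing a 14x14 ViT grid (196 tokens,
# d=768) to a 16x16 grid (256 tokens). The embedding dimension is untouched;
# only the spatial grid is bilinearly resized via `scipy.ndimage.zoom`:
#
#   restored = np.random.randn(196, 768)      # [n_restored_tokens, d]
#   resized = interpolate_positional_embeddings(restored, 256)
#   assert resized.shape == (1, 256, 768)     # [1, n_tokens, d]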


def initialise_from_train_state(
    config,
    train_state: Any,
    restored_train_state: Any,
    restored_model_cfg: ml_collections.ConfigDict,
    restore_output_proj: bool,
    mbt_transformer_key: str = 'Transformer',
    log_initialised_param_shapes: bool = True,
    one_config: bool = True,
    prefix_path: Any = None) -> Any:
"""Updates the train_state with data from restored_train_state.
This function is written to be used for 'fine-tuning' experiments. Here, we
do some surgery to support larger resolutions (longer sequence length) in
the transformer block, with respect to the learned pos-embeddings.
Args:
config: Configurations for the model being updated, or tuple of configs.
train_state: A raw TrainState for the model.
restored_train_state: A TrainState that is loaded with parameters/state of a
pretrained model.
restored_model_cfg: Configuration of the model from which the
restored_train_state come from. Usually used for some asserts.
restore_output_proj: If true, load the final output projection. Set
to False if finetuning to a new dataset.
mbt_transformer_key: The key used for storing the subtree in the
parameters that keeps Transformer weights, that are supposed to be
initialized from the given pre-trained model.
log_initialised_param_shapes: If true, print tabular summary of all the
variables in the model once they have been initialised.
one_config: If true, we have only a single config. If false, we get a tuple
of configs in the order [init_config, model_config, dataset_config]. This
is useful for works that build upon MBT and have different models in their
config.
prefix_path: If parameters are in a subtree.
Returns:
Updated train_state.
"""
  # Split up configs.
  if one_config:
    init_config = config.init_from
    model_config = config.model
    dataset_config = config.dataset_configs
  else:
    init_config, model_config, dataset_config = config

  # Inspect and compare the parameters of the model with the init-model.
  params = flax.core.unfreeze(train_state.optimizer.target)
  logging.info('Parameters in the target model are: %s', params)

  if init_config.get('checkpoint_format', 'scenic') == 'big_vision':
    restored_params = restored_train_state.optimizer['target']
  else:
    restored_params = restored_train_state.optimizer.target
  restored_params = flax.core.unfreeze(restored_params)

  if init_config.get('init_from_vit', True):
    if prefix_path:
      video_params = params[prefix_path]
    else:
      video_params = params
    # Start moving parameters, one-by-one, and apply changes if needed.
    for m_key, m_params in restored_params.items():
      if m_key == 'output_projection':
        if restore_output_proj:
          video_params[m_key] = m_params
        else:
          pass
      elif m_key == 'pre_logits':
        if model_config.representation_size is None:
          # We don't have representation_size in the new model, so let's
          # ignore it from the pretrained model, in case it has it.
          # Note, removing the key from the dictionary is necessary to prevent
          # obscure errors from the Flax optimizer.
          video_params.pop(m_key, None)
        else:
          assert restored_model_cfg.model.representation_size
          video_params[m_key] = m_params
      elif m_key in ['Transformer']:
        for tm_key, tm_params in m_params.items():
          if tm_key == 'posembed_input':  # Might need resolution change.
            init_posemb(
                video_params[mbt_transformer_key],
                m_params,
                init_config,
                model_config,
                dataset_config,
                restored_model_cfg,
                'posembed_input',
                prefix_path=prefix_path)
            init_posemb(
                video_params[mbt_transformer_key],
                m_params,
                init_config,
                model_config,
                dataset_config,
                restored_model_cfg,
                'posembed_input_spectrogram',
                prefix_path=prefix_path)
            init_posemb(
                video_params,
                m_params,
                init_config,
                model_config,
                dataset_config,
                restored_model_cfg,
                'bottleneck',
                prefix_path=prefix_path)
          elif 'encoderblock' in tm_key:
            logging.info('Loading encoder parameters.')
            init_encoderblock(video_params[mbt_transformer_key], m_params,
                              tm_key)
          else:  # Other parameters of the Transformer encoder.
            video_params[mbt_transformer_key][tm_key] = tm_params
      elif m_key == 'embedding':
        init_embedding(video_params, m_params, init_config, model_config,
                       'embedding')
        init_embedding(video_params, m_params, init_config, model_config,
                       'embedding_spectrogram')
      else:
        if m_key in train_state.optimizer.target:
          video_params[m_key] = m_params
        if '%s_spectrogram' % m_key in train_state.optimizer.target:
          video_params['%s_spectrogram' % m_key] = m_params
        else:
          logging.info('Skipping %s. In restored model but not in target.',
                       m_key)
  else:
    for m_key, m_params in restored_params.items():
      if m_key == 'output_projection':
        if restore_output_proj:
          params[m_key] = m_params
        else:
          pass
      elif m_key == 'pre_logits':
        if model_config.representation_size is None:
          # We don't have representation_size in the new model, so let's
          # ignore it from the pretrained model, in case it has it.
          # Note, removing the key from the dictionary is necessary to prevent
          # obscure errors from the Flax optimizer.
          params.pop(m_key, None)
        else:
          assert restored_model_cfg.model.representation_size
          params[m_key] = m_params
      else:
        if m_key in train_state.optimizer.target:
          params[m_key] = m_params
        else:
          logging.info('Skipping %s. In restored model but not in target.',
                       m_key)

  if log_initialised_param_shapes:
    logging.info('Parameter summary after initialising from train state:')
    debug_utils.log_param_shapes(params)
  return train_state.replace(
      optimizer=train_state.optimizer.replace(target=flax.core.freeze(params)))
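
# A hedged usage sketch (argument values are illustrative): when fine-tuning
# MBT on a new dataset, the pretrained checkpoint is typically loaded with a
# fresh output projection, so `restore_output_proj` is set to False:
#
#   train_state = initialise_from_train_state(
#       config, train_state, restored_train_state, restored_model_cfg,
#       restore_output_proj=False)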


def init_posemb(to_params, from_params, init_config, model_config,
                dataset_config, restored_model_cfg, name, prefix_path=None):
  """Initialize the positional embeddings."""
  if name not in to_params:
    logging.info('No %s in target model.', name)
  elif init_config.restore_positional_embedding:
    if name == 'bottleneck':
      posemb = to_params[name]
    else:
      posemb = to_params[name]['pos_embedding']
    restored_posemb = from_params['posembed_input']['pos_embedding']
    if restored_posemb.shape != posemb.shape:
      # Rescale the grid of positional embeddings.
      # Default parameter shape is (1, N, 768).
      logging.info('Adapting positional embeddings %s from %s to %s.',
                   name, restored_posemb.shape, posemb.shape)
      ntok = posemb.shape[1]
      if prefix_path:
        # MBT is part of a larger model.
        classifier = restored_model_cfg.mbt.model.classifier
      else:
        classifier = restored_model_cfg.model.classifier
      if classifier == 'token':
        # The first token is the CLS token.
        cls_tok = restored_posemb[:, :1]
        restored_posemb_grid = restored_posemb[0, 1:]
      else:
        cls_tok = restored_posemb[:, :0]
        restored_posemb_grid = restored_posemb[0]
      if model_config.classifier == 'token':
        ntok -= 1

      size_change = init_config.positional_embed_size_change
      if name == 'bottleneck':
        restored_posemb_grid = interpolate_positional_embeddings(
            restored_posemb_grid, ntok)
      elif size_change == 'tile':
        restored_posemb_grid = tile_positional_embeddings(
            restored_posemb_grid, ntok)
      elif size_change in ['resize_tile', 'resize']:
        temp_encoding = model_config.temporal_encoding_config
        if name.find('spectrogram') > -1:
          gh = ((dataset_config.spec_shape[0] *
                 dataset_config.num_spec_frames) //
                model_config.patches.size[0])
          gw = (dataset_config.spec_shape[1] //
                model_config.patches.size[1])
          tokens_per_frame = (gh, gw)
        elif temp_encoding.method == 'temporal_sampling':
          tokens_per_frame = int(ntok / temp_encoding.n_sampled_frames)
        elif temp_encoding.method == '3d_conv':
          # This is for RGB only.
          n_frames = (
              dataset_config.num_frames //
              model_config.patches.size[2])
          tokens_per_frame = ntok // n_frames
        else:
          raise AssertionError(
              f'Unknown temporal encoding {temp_encoding.method}')
        restored_posemb_grid = interpolate_positional_embeddings(
            restored_posemb_grid, tokens_per_frame)
        if size_change == 'resize_tile' and ntok != tokens_per_frame:
          restored_posemb_grid = restored_posemb_grid[0]
          restored_posemb_grid = tile_positional_embeddings(
              restored_posemb_grid, ntok)
      else:
        raise AssertionError(
            'Unknown positional embedding size changing method.')

      # Attach the CLS token again.
      if model_config.classifier == 'token':
        restored_posemb = jnp.array(
            np.concatenate([cls_tok, restored_posemb_grid], axis=1))
      else:
        restored_posemb = restored_posemb_grid

    if name == 'bottleneck':
      to_params[name] = restored_posemb
    else:
      to_params[name]['pos_embedding'] = restored_posemb
  else:
    logging.info('Not restoring positional encodings from pretrained model.')
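
# A shape sketch for the CLS-token handling above (sizes are illustrative):
# with a 'token' classifier, a restored embedding of shape (1, 197, 768) is
# split into cls_tok of shape (1, 1, 768) and a grid of shape (196, 768); the
# grid alone is resized or tiled to the new token count, and the CLS embedding
# is concatenated back afterwards, giving (1, ntok + 1, 768).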


def init_embedding(to_params, from_params, init_config, model_config, name):
  """Initialize input embedding."""
  if name not in to_params:
    logging.info('No %s in target model.', name)
  elif init_config.get('restore_input_embedding', True):
    input_kernel = to_params[name]['kernel']
    restored_kernel = from_params['kernel']
    restored_bias = from_params['bias']

    if input_kernel.shape != restored_kernel.shape:
      kernel_init_method = (
          model_config.temporal_encoding_config.kernel_init_method)
      if input_kernel.shape == restored_kernel.shape[1:]:
        # Deflates a ViViT 3D embedder to work with 2D spectrogram inputs.
        restored_kernel = np.mean(restored_kernel, axis=0)
      elif input_kernel.shape[1:] != restored_kernel.shape:
        # Kernel dimensions are [t, c_in, c_out].
        restored_kernel = np.reshape(restored_kernel, input_kernel.shape)
      elif input_kernel.shape[0] == 1:
        # Kernel dimensions are [t, h, w, c_in, c_out].
        restored_kernel = np.expand_dims(restored_kernel, axis=0)
      elif kernel_init_method == 'average_frame_initializer':
        # This corresponds to "filter inflation" in
        # "J Carreira and A Zisserman. Quo vadis, action recognition?
        # A new model and the Kinetics dataset. CVPR 2017".
        logging.info('Initializing input kernel with filter inflation.')
        t = input_kernel.shape[0]
        restored_kernel = np.expand_dims(restored_kernel, axis=0)
        restored_kernel = np.tile(restored_kernel, [t, 1, 1, 1, 1]) / t
      elif kernel_init_method == 'average_arp_frame_initializer':
        # This corresponds to a combination of filter inflation and
        # the approximate rank pooling described in
        # "H Bilen et al. Action Recognition with Dynamic Image Networks.
        # PAMI 2017".
        logging.info('Initializing input kernel with ARP inflation.')
        t = input_kernel.shape[0]
        restored_kernel = np.expand_dims(restored_kernel, axis=0)
        restored_kernel = np.tile(restored_kernel, [t, 1, 1, 1, 1])

        def average_arp(length):
          # Implements Equation 3 of Bilen et al. PAMI 2017.
          array = np.arange(1, length + 1)
          harmonic = np.zeros((length + 1))
          harmonic[1:] = np.cumsum(1.0 / array)
          array = 2 * (length - array + 1) - (length + 1) * (
              harmonic[-1] - harmonic[:-1])
          return array

        normalizer = average_arp(t) / t
        normalizer = np.reshape(normalizer, [t, 1, 1, 1, 1])
        restored_kernel = restored_kernel * normalizer
      elif kernel_init_method == 'central_frame_initializer':
        logging.info('Initializing input kernel to select centre frame.')
        central_time_index = input_kernel.shape[0] // 2
        temp = np.zeros(input_kernel.shape)
        temp[central_time_index] = restored_kernel.copy()
        restored_kernel = temp
      else:
        raise AssertionError(
            'Unknown input kernel initialization {}'.format(
                kernel_init_method))

    to_params[name]['kernel'] = restored_kernel
    to_params[name]['bias'] = restored_bias
  else:
    logging.info('Not restoring input embedding parameters.')
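
# A hedged sketch of the 'average_frame_initializer' branch (shapes are
# illustrative): a 2D ViT patch-embedding kernel of shape
# [h, w, c_in, c_out] = [16, 16, 3, 768] is inflated to a 3D kernel
# [t, h, w, c_in, c_out] = [2, 16, 16, 3, 768] by tiling along time and
# dividing by t, so the inflated filter's response to a static clip matches
# the original 2D filter's response to a single frame:
#
#   kernel_2d = np.random.randn(16, 16, 3, 768)
#   kernel_3d = np.tile(kernel_2d[np.newaxis], [2, 1, 1, 1, 1]) / 2
#   assert kernel_3d.shape == (2, 16, 16, 3, 768)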


def init_encoderblock(to_params, from_params, tm_key):
  """Initialize encoder block parameters."""
  # Explicitly enumerate over the keys in the encoder-block. Don't just
  # assign the dictionary. It is possible for the target model to
  # contain keys that are not in the restored model.
  for enc_key in from_params[tm_key].keys():
    restoring_params = False
    if tm_key in to_params:
      assert enc_key in to_params[tm_key], '%s not in to_params[%s]' % (
          enc_key, tm_key)
      to_params[tm_key][enc_key] = from_params[tm_key][enc_key]
      restoring_params = True
    spec_key = '%s_spectrogram' % tm_key
    if spec_key in to_params:
      assert enc_key in to_params[spec_key], '%s not in to_params[%s]' % (
          enc_key, spec_key)
      to_params[spec_key][enc_key] = from_params[tm_key][enc_key]
      restoring_params = True
    if not restoring_params:
      logging.info('Warning: Not restoring encoder parameters.')
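
# A hedged illustration of the key layout assumed above (names are
# illustrative): a multimodal MBT target can hold two copies of each encoder
# block, e.g. to_params['encoderblock_0'] for RGB and
# to_params['encoderblock_0_spectrogram'] for the audio spectrogram; both are
# initialised from the single pretrained from_params['encoderblock_0'].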