"""Vision Transformer.""" |
|
|
|
|
|
from typing import Any, Callable, Optional, Sequence |
|
|
|
|
|
from absl import logging |
|
|
import flax |
|
|
import flax.linen as nn |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
import ml_collections |
|
|
import numpy as np |
|
|
from scenic.model_lib.base_models.multilabel_classification_model import MultiLabelClassificationModel |
|
|
from scenic.model_lib.layers import attention_layers |
|
|
from scenic.model_lib.layers import nn_layers |
|
|
import scipy |
|
|
from tensorflow.io import gfile |
|
|
|
|
|
Initializer = Callable[[jnp.ndarray, Sequence[int], jnp.dtype], jnp.ndarray] |


class AddPositionEmbs(nn.Module):
  """Adds learned positional embeddings to the inputs.

  Attributes:
    posemb_init: Positional embedding initializer.

  Returns:
    Output in shape `[bs, timesteps, in_dim]`.
  """

  posemb_init: Initializer = nn.initializers.normal(stddev=0.02)

  @nn.compact
  def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
    assert inputs.ndim == 3, ('Number of dimensions should be 3,'
                              ' but it is: %d' % inputs.ndim)
    pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
    pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape,
                    inputs.dtype)
    return inputs + pe
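

# Illustrative usage sketch (added for exposition; not part of the original
# module): `AddPositionEmbs` preserves the input shape, adding a learned
# (1, num_tokens, dim) embedding broadcast over the batch. The function name
# and all shapes below are assumptions chosen for demonstration only.
def _example_add_position_embs():
  rng = jax.random.PRNGKey(0)
  tokens = jnp.zeros((2, 16, 8))  # [batch, num_tokens, hidden_dim].
  module = AddPositionEmbs()
  variables = module.init(rng, tokens)  # Creates the (1, 16, 8) embedding.
  out = module.apply(variables, tokens)
  assert out.shape == tokens.shape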


class MAPHead(nn.Module):
  """Multihead Attention Pooling."""

  mlp_dim: Optional[int] = None
  num_heads: int = 12
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x):
    n, _, d = x.shape
    probe = self.param('probe', nn.initializers.xavier_uniform(), (1, 1, d),
                       x.dtype)
    probe = jnp.tile(probe, [n, 1, 1])

    x = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        kernel_init=nn.initializers.xavier_uniform())(probe, x)

    y = nn.LayerNorm()(x)
    x = x + attention_layers.MlpBlock(
        mlp_dim=self.mlp_dim,
        dtype=self.dtype,
        dropout_rate=0.0)(y, deterministic=True)
    return x[:, 0]
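

# Illustrative usage sketch (added for exposition; not part of the original
# module): `MAPHead` pools a `[batch, tokens, dim]` sequence down to
# `[batch, dim]` via cross-attention from a learned probe token. Shapes and
# the function name are assumptions for demonstration only.
def _example_map_head():
  rng = jax.random.PRNGKey(0)
  tokens = jnp.ones((2, 16, 8))
  head = MAPHead(num_heads=2, mlp_dim=32)
  pooled, _ = head.init_with_output(rng, tokens)
  assert pooled.shape == (2, 8)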


class Encoder1DBlock(nn.Module):
  """Transformer encoder layer.

  Attributes:
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: Number of self-attention heads.
    dtype: The dtype of the computation (default: float32).
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    stochastic_depth: The probability of dropping a layer; grows linearly
      from 0 to the provided value.

  Returns:
    Output after transformer encoder block.
  """

  mlp_dim: int
  num_heads: int
  dtype: Any = jnp.float32
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_depth: float = 0.0

  @nn.compact
  def __call__(self, inputs: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
    """Applies Encoder1DBlock module.

    Args:
      inputs: Input data.
      deterministic: Deterministic or not (to apply dropout).

    Returns:
      Output after transformer encoder block.
    """
    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(dtype=self.dtype)(inputs)
    x = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        dtype=self.dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        dropout_rate=self.attention_dropout_rate)(x, x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
    x = nn_layers.StochasticDepth(rate=self.stochastic_depth)(x, deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = attention_layers.MlpBlock(
        mlp_dim=self.mlp_dim,
        dtype=self.dtype,
        dropout_rate=self.dropout_rate,
        activation_fn=nn.gelu,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))(
            y, deterministic=deterministic)
    y = nn_layers.StochasticDepth(rate=self.stochastic_depth)(y, deterministic)
    return y + x
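

# Illustrative usage sketch (added for exposition; not part of the original
# module): a single pre-LN encoder block maps `[batch, tokens, dim]` to the
# same shape. Shapes and the function name are assumptions for demonstration.
def _example_encoder_block():
  rng = jax.random.PRNGKey(0)
  tokens = jnp.ones((2, 16, 8))
  block = Encoder1DBlock(mlp_dim=32, num_heads=2)
  out, _ = block.init_with_output(rng, tokens, deterministic=True)
  assert out.shape == tokens.shape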


class Encoder(nn.Module):
  """Transformer Encoder.

  Attributes:
    num_layers: Number of layers.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: The number of heads for multi-head self-attention.
    positional_embedding: The type of positional embeddings to add to the
      input tokens. Options are {learned_1d, sinusoidal_1d, sinusoidal_2d,
      none}.
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout rate for attention heads.
    stochastic_depth: The probability of dropping a layer grows linearly
      from 0 to the provided value. Our implementation of stochastic depth
      follows the timm library, which does per-example layer dropping and uses
      independent dropping patterns for each skip-connection.
    dtype: Dtype of activations.
    has_cls_token: Whether or not the sequence is prepended with a CLS token.
  """

  num_layers: int
  mlp_dim: int
  num_heads: int
  positional_embedding: str = 'learned_1d'
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_depth: float = 0.0
  dtype: Any = jnp.float32
  has_cls_token: bool = False

  @nn.compact
  def __call__(self, inputs: jnp.ndarray, *, train: bool = False):
    """Applies Transformer model on the inputs.

    Args:
      inputs: Input tokens of shape [batch, num_tokens, channels].
      train: If in training mode, dropout and stochastic depth are applied.

    Returns:
      Encoded tokens.
    """
    assert inputs.ndim == 3
    dtype = jax.dtypes.canonicalize_dtype(self.dtype)

    # Add positional embeddings to the tokens.
    if self.positional_embedding == 'learned_1d':
      x = AddPositionEmbs(
          posemb_init=nn.initializers.normal(stddev=0.02),
          name='posembed_input')(inputs)
    elif self.positional_embedding == 'sinusoidal_1d':
      x = attention_layers.Add1DPositionEmbedding(posemb_init=None)(inputs)
    elif self.positional_embedding == 'sinusoidal_2d':
      batch, num_tokens, hidden_dim = inputs.shape
      if self.has_cls_token:
        num_tokens -= 1
      height = width = int(np.sqrt(num_tokens))
      if height * width != num_tokens:
        raise ValueError('Input is assumed to be square for sinusoidal init.')
      if self.has_cls_token:
        inputs_reshape = inputs[:, 1:].reshape(
            [batch, height, width, hidden_dim])
        x = attention_layers.AddFixedSinCosPositionEmbedding()(inputs_reshape)
        x = x.reshape([batch, num_tokens, hidden_dim])
        x = jnp.concatenate([inputs[:, :1], x], axis=1)
      else:
        inputs_reshape = inputs.reshape([batch, height, width, hidden_dim])
        x = attention_layers.AddFixedSinCosPositionEmbedding()(inputs_reshape)
        x = x.reshape([batch, num_tokens, hidden_dim])
    elif self.positional_embedding == 'none':
      x = inputs
    else:
      raise ValueError(
          f'Unknown positional embedding: {self.positional_embedding}')
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

    # Stack of encoder blocks; the stochastic-depth rate grows linearly with
    # the layer index.
    for lyr in range(self.num_layers):
      x = Encoder1DBlock(
          mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
          dropout_rate=self.dropout_rate,
          attention_dropout_rate=self.attention_dropout_rate,
          stochastic_depth=(lyr / max(self.num_layers - 1, 1))
          * self.stochastic_depth,
          name=f'encoderblock_{lyr}',
          dtype=dtype,
      )(x, deterministic=not train)
    encoded = nn.LayerNorm(name='encoder_norm')(x)
    return encoded
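

# Illustrative sketch (added for exposition): the per-layer stochastic-depth
# rates implied by the formula used in `Encoder.__call__` above. The numbers
# are assumptions chosen to make the linear schedule visible.
def _example_stochastic_depth_schedule():
  num_layers, stochastic_depth = 4, 0.3
  rates = [(lyr / max(num_layers - 1, 1)) * stochastic_depth
           for lyr in range(num_layers)]
  np.testing.assert_allclose(rates, [0.0, 0.1, 0.2, 0.3])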


class ViT(nn.Module):
  """Vision Transformer model.

  Attributes:
    num_classes: Number of output classes.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_layers: Number of layers.
    num_heads: Number of self-attention heads.
    patches: Configuration of the patches extracted in the stem of the model.
    hidden_size: Size of the hidden state of the output of model's stem.
    positional_embedding: The type of positional embeddings to add to the
      tokens at the beginning of the transformer encoder. Options are
      {learned_1d, sinusoidal_1d, sinusoidal_2d, none}.
    representation_size: Size of the representation layer in the model's head.
      If None, we skip the extra projection + tanh activation at the end.
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    stochastic_depth: Overall stochastic-depth rate; the per-layer drop
      probability grows linearly from 0 to this value.
    classifier: Type of the classifier layer. Options are 'gap', 'gmp', 'gsp',
      'token', 'map', 'none'.
    dtype: JAX data type for activations.
  """

  num_classes: int
  mlp_dim: int
  num_layers: int
  num_heads: int
  patches: ml_collections.ConfigDict
  hidden_size: int
  positional_embedding: str = 'learned_1d'
  representation_size: Optional[int] = None
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_depth: float = 0.0
  classifier: str = 'gap'
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, x: jnp.ndarray, *, train: bool, debug: bool = False):
    # Extract patches with a strided convolution (the model's stem).
    fh, fw = self.patches.size
    x = nn.Conv(
        self.hidden_size, (fh, fw),
        strides=(fh, fw),
        padding='VALID',
        name='embedding')(x)
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if self.classifier == 'token':
      cls = self.param('cls', nn.initializers.zeros, (1, 1, c), x.dtype)
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(
        mlp_dim=self.mlp_dim,
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        positional_embedding=self.positional_embedding,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        stochastic_depth=self.stochastic_depth,
        dtype=self.dtype,
        has_cls_token=self.classifier == 'token',
        name='Transformer',
    )(x, train=train)

    if self.classifier in ('token', '0'):
      x = x[:, 0]
    elif self.classifier in ('gap', 'gmp', 'gsp'):
      fn = {'gap': jnp.mean, 'gmp': jnp.max, 'gsp': jnp.sum}[self.classifier]
      x = fn(x, axis=1)
    elif self.classifier == 'map':
      x = MAPHead(
          num_heads=self.num_heads, mlp_dim=self.mlp_dim, dtype=self.dtype)(x)
    elif self.classifier == 'none':
      pass
    else:
      raise ValueError(f'Unknown classifier {self.classifier}')

    if self.representation_size is not None:
      x = nn.Dense(self.representation_size, name='pre_logits')(x)
      x = nn.tanh(x)
    else:
      x = nn_layers.IdentityLayer(name='pre_logits')(x)

    if self.num_classes > 0:
      # If num_classes <= 0, we just return the backbone features.
      x = nn.Dense(
          self.num_classes,
          kernel_init=nn.initializers.zeros,
          name='output_projection')(x)
    return x
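

# Illustrative usage sketch (added for exposition; not part of the original
# module): a forward pass through a deliberately tiny ViT. All sizes are
# assumptions chosen for speed, not recommended settings.
def _example_vit_forward():
  rng = jax.random.PRNGKey(0)
  images = jnp.ones((2, 32, 32, 3))  # 32x32 images, 4x4 patches: 64 tokens.
  model = ViT(
      num_classes=10,
      mlp_dim=32,
      num_layers=1,
      num_heads=2,
      patches=ml_collections.ConfigDict({'size': (4, 4)}),
      hidden_size=16,
      classifier='gap',
      dropout_rate=0.0,
      attention_dropout_rate=0.0)
  logits, _ = model.init_with_output(rng, images, train=False)
  assert logits.shape == (2, 10)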


class ViTMultiLabelClassificationModel(MultiLabelClassificationModel):
  """Vision Transformer model for multi-label classification task."""

  def build_flax_model(self) -> nn.Module:
    dtype_str = self.config.get('model_dtype_str', 'float32')
    if dtype_str != 'float32':
      raise ValueError('`dtype` argument is not propagated properly '
                       'in the current implementation, so only '
                       '`float32` is supported for now.')
    return ViT(
        num_classes=self.dataset_meta_data['num_classes'],
        mlp_dim=self.config.model.mlp_dim,
        num_layers=self.config.model.num_layers,
        num_heads=self.config.model.num_heads,
        positional_embedding=self.config.model.get('positional_embedding',
                                                   'learned_1d'),
        representation_size=self.config.model.representation_size,
        patches=self.config.model.patches,
        hidden_size=self.config.model.hidden_size,
        classifier=self.config.model.classifier,
        dropout_rate=self.config.model.get('dropout_rate', 0.1),
        attention_dropout_rate=self.config.model.get('attention_dropout_rate',
                                                     0.1),
        stochastic_depth=self.config.model.get('stochastic_depth', 0.0),
        dtype=getattr(jnp, dtype_str),
    )

  def default_flax_model_config(self) -> ml_collections.ConfigDict:
    return ml_collections.ConfigDict({
        'model':
            dict(
                num_heads=2,
                num_layers=1,
                representation_size=16,
                mlp_dim=32,
                dropout_rate=0.,
                attention_dropout_rate=0.,
                hidden_size=16,
                patches={'size': (4, 4)},
                classifier='gap',
                data_dtype_str='float32')
    })

  def init_from_train_state(
      self, train_state: Any, restored_train_state: Any,
      restored_model_cfg: ml_collections.ConfigDict) -> Any:
    """Updates the train_state with data from restored_train_state.

    This function is written to be used for 'fine-tuning' experiments. Here, we
    do some surgery to support larger resolutions (longer sequence length) in
    the transformer block, with respect to the learned pos-embeddings.

    Args:
      train_state: A raw TrainState for the model.
      restored_train_state: A TrainState that is loaded with parameters/state
        of a pretrained model.
      restored_model_cfg: Configuration of the model from which the
        restored_train_state comes. Usually used for some asserts.

    Returns:
      Updated train_state.
    """
    return init_vit_from_train_state(train_state, restored_train_state,
                                     self.config, restored_model_cfg)

  def load_augreg_params(self, train_state: Any, params_path: str,
                         model_cfg: ml_collections.ConfigDict) -> Any:
    """Loads parameters from an AugReg checkpoint.

    See
    https://github.com/google-research/vision_transformer/
    and
    https://arxiv.org/abs/2106.10270
    for more information about these pre-trained models.

    Args:
      train_state: A raw TrainState for the model.
      params_path: Path to an AugReg checkpoint. The model config is read from
        the filename (e.g. a B/32 model starts with "B_32-").
      model_cfg: Configuration of the model. Usually used for some asserts.

    Returns:
      Updated train_state with params replaced with the ones read from the
      AugReg checkpoint.
    """
    restored_model_cfg = _get_augreg_cfg(params_path)
    assert tuple(restored_model_cfg.patches.size) == tuple(
        model_cfg.patches.size)
    assert restored_model_cfg.hidden_size == model_cfg.hidden_size
    assert restored_model_cfg.mlp_dim == model_cfg.mlp_dim
    assert restored_model_cfg.num_layers == model_cfg.num_layers
    assert restored_model_cfg.num_heads == model_cfg.num_heads
    assert restored_model_cfg.classifier == model_cfg.classifier

    flattened = np.load(gfile.GFile(params_path, 'rb'))
    restored_params = flax.traverse_util.unflatten_dict(
        {tuple(k.split('/')): v for k, v in flattened.items()})
    restored_params['output_projection'] = restored_params.pop('head')

    if hasattr(train_state, 'optimizer'):
      # Old-style train state that keeps its parameters in a flax.optim
      # optimizer (deprecated).
      params = flax.core.unfreeze(train_state.optimizer.target)
      _merge_params(params, restored_params, model_cfg, restored_model_cfg)
      return train_state.replace(
          optimizer=train_state.optimizer.replace(
              target=flax.core.freeze(params)))
    else:
      params = flax.core.unfreeze(train_state.params)
      _merge_params(params, restored_params, model_cfg, restored_model_cfg)
      return train_state.replace(params=flax.core.freeze(params))
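

# Illustrative sketch (added for exposition): AugReg checkpoints are flat
# `.npz` mappings whose keys encode the parameter tree with '/' separators;
# this mirrors the unflattening performed in `load_augreg_params` above. The
# function name and dummy shapes are assumptions.
def _example_unflatten_checkpoint():
  flattened = {'head/kernel': np.zeros((8, 2)), 'head/bias': np.zeros((2,))}
  tree = flax.traverse_util.unflatten_dict(
      {tuple(k.split('/')): v for k, v in flattened.items()})
  assert set(tree['head']) == {'kernel', 'bias'}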


def _get_augreg_cfg(params_path):
  """Returns the config of the AugReg model matching the checkpoint filename."""
  model = params_path.split('/')[-1].split('-')[0]
  assert model in ('B_16', 'B_32'), (
      'Currently only B/16 and B/32 models are supported.')
  sz = {'B_16': 16, 'B_32': 32}[model]
  return ml_collections.ConfigDict(
      dict(
          num_classes=0,
          mlp_dim=3072,
          num_layers=12,
          num_heads=12,
          hidden_size=768,
          classifier='token',
          patches=dict(size=(sz, sz)),
          dropout_rate=0.,
          attention_dropout_rate=0.,
      ))
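

# Illustrative sketch (added for exposition): the model variant is parsed
# from the checkpoint filename. The path below is hypothetical.
def _example_augreg_cfg_from_filename():
  cfg = _get_augreg_cfg('/tmp/B_32-i21k.npz')
  assert cfg.patches.size == (32, 32)
  assert cfg.hidden_size == 768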


def _merge_params(params, restored_params, model_cfg, restored_model_cfg):
  """Merges `restored_params` into `params`."""
  for m_key, m_params in restored_params.items():
    if m_key == 'output_projection':
      # For the classifier head, we use the randomly initialized params and
      # ignore the one from the pretrained model.
      pass

    elif m_key == 'pre_logits':
      if model_cfg.model.representation_size is None:
        # The target model has no representation layer, so we drop the
        # restored one, in case it exists. Removing the key from the
        # dictionary is necessary to prevent obscure errors from the Flax
        # optimizer.
        params.pop(m_key, None)
      else:
        assert restored_model_cfg.model.representation_size
        params[m_key] = m_params

    elif m_key == 'Transformer':
      for tm_key, tm_params in m_params.items():
        if tm_key == 'posembed_input':  # Might need a resolution change.
          posemb = params[m_key]['posembed_input']['pos_embedding']
          restored_posemb = m_params['posembed_input']['pos_embedding']

          if restored_posemb.shape != posemb.shape:
            # Rescale the grid of pos embeddings: param shape is (1, N, d).
            logging.info('Resized variant: %s to %s', restored_posemb.shape,
                         posemb.shape)
            ntok = posemb.shape[1]
            if restored_model_cfg.model.classifier == 'token':
              # The first token of the restored model is the CLS token.
              restored_posemb_grid = restored_posemb[0, 1:]
              if model_cfg.model.classifier == 'token':
                # The CLS token is also present in the target model.
                cls_tok = restored_posemb[:, :1]
                ntok -= 1
              else:
                # No CLS token in the target model; keep an empty slice.
                cls_tok = restored_posemb[:, :0]
            else:
              restored_posemb_grid = restored_posemb[0]
              if model_cfg.model.classifier == 'token':
                # CLS token exists only in the target model; keep its randomly
                # initialized embedding.
                cls_tok = posemb[:, :1]
                ntok -= 1
              else:
                cls_tok = restored_posemb[:, :0]

            restored_gs = int(np.sqrt(len(restored_posemb_grid)))
            gs = int(np.sqrt(ntok))
            if restored_gs != gs:  # We need a resolution change.
              logging.info('Grid-size from %s to %s.', restored_gs, gs)
              restored_posemb_grid = restored_posemb_grid.reshape(
                  restored_gs, restored_gs, -1)
              zoom = (gs / restored_gs, gs / restored_gs, 1)
              restored_posemb_grid = scipy.ndimage.zoom(
                  restored_posemb_grid, zoom, order=1)

            # Attach the CLS token again.
            restored_posemb_grid = restored_posemb_grid.reshape(
                1, gs * gs, -1)
            restored_posemb = jnp.array(
                np.concatenate([cls_tok, restored_posemb_grid], axis=1))

          params[m_key][tm_key]['pos_embedding'] = restored_posemb

        elif tm_key in params[m_key]:
          params[m_key][tm_key] = tm_params
        else:
          logging.info('Ignoring %s. In restored model\'s Transformer, '
                       'but not in target.', tm_key)

    elif m_key in params:
      # Use the remaining parameters if they exist in the target model.
      params[m_key] = m_params

    else:
      logging.info('Ignoring %s. In restored model, but not in target', m_key)
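

# Illustrative sketch (added for exposition): the position-embedding grid
# resizing performed in `_merge_params`, in isolation. A 14x14 grid of 8-dim
# embeddings is bilinearly interpolated to 16x16; all sizes are assumptions.
def _example_resize_posemb_grid():
  restored = np.random.RandomState(0).randn(14 * 14, 8)
  grid = restored.reshape(14, 14, -1)
  zoom = (16 / 14, 16 / 14, 1)
  resized = scipy.ndimage.zoom(grid, zoom, order=1)
  assert resized.shape == (16, 16, 8)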


def init_vit_from_train_state(
    train_state: Any, restored_train_state: Any,
    model_cfg: ml_collections.ConfigDict,
    restored_model_cfg: ml_collections.ConfigDict) -> Any:
  """Updates the train_state with data from restored_train_state.

  This function is written to be used for 'fine-tuning' experiments. Here, we
  do some surgery to support larger resolutions (longer sequence length) in
  the transformer block, with respect to the learned pos-embeddings.

  The function supports train_states using either Optax or flax.optim (the
  latter has been deprecated and will be removed from Scenic).

  Args:
    train_state: A raw TrainState for the model.
    restored_train_state: A TrainState that is loaded with parameters/state of
      a pretrained model.
    model_cfg: Configuration of the model. Usually used for some asserts.
    restored_model_cfg: Configuration of the model from which the
      restored_train_state comes. Usually used for some asserts.

  Returns:
    Updated train_state.
  """
  if hasattr(train_state, 'optimizer'):
    # Old-style train state that keeps its parameters in a flax.optim
    # optimizer (deprecated).
    params = flax.core.unfreeze(train_state.optimizer.target)
    restored_params = flax.core.unfreeze(restored_train_state.optimizer.target)
    _merge_params(params, restored_params, model_cfg, restored_model_cfg)
    return train_state.replace(
        optimizer=train_state.optimizer.replace(
            target=flax.core.freeze(params)))
  else:
    params = flax.core.unfreeze(train_state.params)
    restored_params = flax.core.unfreeze(restored_train_state.params)
    _merge_params(params, restored_params, model_cfg, restored_model_cfg)
    return train_state.replace(params=flax.core.freeze(params))