# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT Model."""
from typing import Any, Dict
from absl import logging
import flax
import flax.linen as nn
import jax.numpy as jnp
import ml_collections
from scenic.common_lib import debug_utils
from scenic.model_lib.base_models.classification_model import ClassificationModel
from scenic.model_lib.base_models.regression_model import RegressionModel
from scenic.projects.baselines.bert import bert_base_model
from scenic.projects.baselines.bert import layers
class BERT(nn.Module):
"""BERT."""
stem_config: ml_collections.ConfigDict
encoder_config: ml_collections.ConfigDict
head_config: ml_collections.ConfigDict
dtype: Any = jnp.float32
  @nn.compact
  def __call__(self,
               inputs: Dict[str, jnp.ndarray],
               *,
               train: bool,
               transfer_mode: bool = False,
               debug: bool = False):
    """Applies BERT to tokenized `inputs`.

    `inputs` must contain 'input_word_ids', 'input_type_ids' and 'input_mask',
    plus 'masked_lm_positions' for the pretraining head. Returns a dict of
    'nsp_logits' and 'mlm_logits' when pretraining (only the NSP logits if
    `transfer_mode` is set), otherwise the head logits.
    """
x, word_embeddings = layers.Stem(
vocab_size=self.stem_config.vocab_size,
type_vocab_size=self.stem_config.type_vocab_size,
hidden_size=self.stem_config.hidden_size,
max_position_embeddings=self.stem_config.max_position_embeddings,
dropout_rate=self.stem_config.dropout_rate,
embedding_width=self.stem_config.get('embedding_width'),
dtype=self.dtype,
name='stem')(
input_word_ids=inputs['input_word_ids'],
input_type_ids=inputs['input_type_ids'],
input_mask=inputs['input_mask'],
train=train)
x = layers.BERTEncoder(
mlp_dim=self.encoder_config.mlp_dim,
num_layers=self.encoder_config.num_layers,
num_heads=self.encoder_config.num_heads,
dropout_rate=self.encoder_config.dropout_rate,
attention_dropout_rate=self.encoder_config.attention_dropout_rate,
pre_norm=self.encoder_config.pre_norm,
dtype=self.dtype,
name='bert_encoder')(
x, input_mask=inputs['input_mask'], train=train)
if self.head_config.type == 'pretraining':
next_sentence_prediction_logits = layers.ClassificationHead(
num_outputs=2,
hidden_sizes=(x.shape[-1], self.head_config.hidden_size),
nonlinearity=nn.tanh,
dtype=self.dtype,
name='next_sentence_prediction_head')(
x, train=train)
if transfer_mode:
# Next sentence prediction head is a classification head and we can
# reuse it for transfer evaluation on classification tasks.
return next_sentence_prediction_logits
masked_language_modeling_logits = layers.MaskedLanguageModelHead(
dtype=self.dtype, name='masked_language_model_head')(
x, inputs['masked_lm_positions'], word_embeddings, train=train)
return {
'nsp_logits': next_sentence_prediction_logits,
'mlm_logits': masked_language_modeling_logits
}
elif self.head_config.type == 'classification':
return layers.ClassificationHead(
num_outputs=self.head_config.num_classes,
hidden_sizes=(x.shape[-1], self.head_config.hidden_size),
nonlinearity=nn.tanh,
dtype=self.dtype,
name='classification_head')(
x, train=train)
elif self.head_config.type == 'regression':
return layers.ClassificationHead(
num_outputs=1,
hidden_sizes=(x.shape[-1], self.head_config.hidden_size),
nonlinearity=nn.tanh,
dtype=self.dtype,
name='regression_head')(
x, train=train)
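
# Illustrative sketch (assumptions, not taken from this project's configs):
# the `BERT` module above expects three ConfigDicts shaped roughly like the
# following; the numeric values are typical BERT-base settings, used here only
# to show the expected fields.
#
#   model_cfg = ml_collections.ConfigDict()
#   model_cfg.stem = ml_collections.ConfigDict(dict(
#       vocab_size=30522, type_vocab_size=2, hidden_size=768,
#       max_position_embeddings=512, dropout_rate=0.1))
#   model_cfg.encoder = ml_collections.ConfigDict(dict(
#       mlp_dim=3072, num_layers=12, num_heads=12, dropout_rate=0.1,
#       attention_dropout_rate=0.1, pre_norm=False))
#   model_cfg.head = ml_collections.ConfigDict(dict(
#       type='classification', num_classes=2, hidden_size=768))
#   model = BERT(stem_config=model_cfg.stem, encoder_config=model_cfg.encoder,
#                head_config=model_cfg.head)
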
class BERTModel(bert_base_model.BERTBaseModel):
"""BERT model."""
def build_flax_model(self):
model_dtype = getattr(jnp, self.config.get('model_dtype_str', 'float32'))
with self.config.unlocked():
# Add vocabulary information from dataset meta-data to configs:
self.config.model.stem.vocab_size = self.dataset_meta_data['vocab_size']
self.config.model.stem.type_vocab_size = self.dataset_meta_data[
'type_vocab_size']
return BERT(
stem_config=self.config.model.stem,
encoder_config=self.config.model.encoder,
head_config=self.config.model.head,
dtype=model_dtype,
)
def init_from_bert_train_state(
self, train_state: Any, restored_train_state: Any,
restored_model_cfg: ml_collections.ConfigDict) -> Any:
"""Updates the train_state with data from restored_train_state."""
raise NotImplementedError
class BERTClassificationModel(ClassificationModel):
"""BERT Classification model."""
def build_flax_model(self):
model_dtype = getattr(jnp, self.config.get('model_dtype_str', 'float32'))
with self.config.unlocked():
# Add vocabulary information from dataset meta-data to configs:
self.config.model.stem.vocab_size = self.dataset_meta_data['vocab_size']
self.config.model.stem.type_vocab_size = self.dataset_meta_data[
'type_vocab_size']
self.config.model.head.num_classes = self.dataset_meta_data['num_classes']
return BERT(
stem_config=self.config.model.stem,
encoder_config=self.config.model.encoder,
head_config=self.config.model.head,
dtype=model_dtype,
)
def init_from_train_state(
self, train_state: Any, restored_train_state: Any,
restored_model_cfg: ml_collections.ConfigDict) -> Any:
"""Updates the train_state with data from restored_train_state."""
return init_bert_from_train_state(train_state, restored_train_state,
self.config, restored_model_cfg)
class BERTRegressionModel(RegressionModel):
"""BERT Regression model."""
def build_flax_model(self):
model_dtype = getattr(jnp, self.config.get('model_dtype_str', 'float32'))
with self.config.unlocked():
# Add vocabulary information from dataset meta-data to configs:
self.config.model.stem.vocab_size = self.dataset_meta_data['vocab_size']
self.config.model.stem.type_vocab_size = self.dataset_meta_data[
'type_vocab_size']
return BERT(
stem_config=self.config.model.stem,
encoder_config=self.config.model.encoder,
head_config=self.config.model.head,
dtype=model_dtype,
)
def init_from_train_state(
self, train_state: Any, restored_train_state: Any,
restored_model_cfg: ml_collections.ConfigDict) -> Any:
"""Updates the train_state with data from restored_train_state."""
return init_bert_from_train_state(train_state, restored_train_state,
self.config, restored_model_cfg)
def init_bert_from_train_state(
train_state: Any, restored_train_state: Any,
config: ml_collections.ConfigDict,
restored_model_cfg: ml_collections.ConfigDict) -> Any:
"""Updates the train_state with data from restored_train_state."""
del restored_model_cfg
def _get_param_dict(params):
return {
'/'.join([str(kk)
for kk in k]): v
for k, v in flax.traverse_util.flatten_dict(params).items()
}
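
  # `_get_param_dict` flattens a nested parameter tree into a flat dict keyed
  # by '/'-joined paths, e.g. (illustrative key names) a leaf stored under
  # params['bert_encoder']['layer_norm']['scale'] becomes
  # 'bert_encoder/layer_norm/scale'. This makes it easy to match and rename
  # parameters by string key below.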
if hasattr(train_state, 'optimizer'):
# TODO(dehghani): Remove support for flax optim.
params = flax.core.unfreeze(train_state.optimizer.target)
restored_params = flax.core.unfreeze(
restored_train_state.optimizer.target)
else:
params = flax.core.unfreeze(train_state.params)
restored_params = flax.core.unfreeze(restored_train_state.params)
params_dict = _get_param_dict(params)
  # Fix some names: when the pretrained next-sentence-prediction head is
  # reused as the classification head (see
  # `restore_next_sentence_prediction_head_params`), rename its parameters to
  # match the fine-tuning model:
restored_params_dict = dict()
for key, value in flax.traverse_util.flatten_dict(restored_params).items():
name = '/'.join([str(k) for k in key])
if config.init_from.restore_next_sentence_prediction_head_params:
name = name.replace('next_sentence_prediction_head',
'classification_head')
restored_params_dict[name] = value
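  # For example (key names illustrative), with
  # `restore_next_sentence_prediction_head_params` set, a restored key such as
  # 'next_sentence_prediction_head/<submodule>/kernel' is stored under
  # 'classification_head/<submodule>/kernel' so it lines up with the
  # fine-tuning model's parameter tree.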
# Copy parameters over:
for pname, pvalue in restored_params_dict.items():
if 'masked_language_model_head' in pname:
# We throw away parameters of `masked_language_model_head`, but
# for the `next_sentence_prediction_head`, we only discard the final
      # dense (`output_projection`) that maps model representation to the
# label space.
continue
if (not config.init_from.restore_next_sentence_prediction_head_params and
'next_sentence_prediction_head' in pname):
continue
if 'output_projection' in pname:
continue
elif pname in params_dict:
params_dict[pname] = pvalue
else:
logging.error("Restored key doesn't exist in the model: %s.", pname)
logging.info('Inspect missing keys from the restored params:\n%s',
params_dict.keys() - restored_params_dict.keys())
  logging.info('Inspect extra keys in the restored params:\n%s',
restored_params_dict.keys() - params_dict.keys())
splitkeys = {tuple(k.split('/')): v for k, v in params_dict.items()}
params = flax.traverse_util.unflatten_dict(splitkeys)
logging.info('Parameter summary after initialising from train state:')
debug_utils.log_param_shapes(params)
if hasattr(train_state, 'optimizer'):
# TODO(dehghani): Remove support for flax optim.
return train_state.replace(
optimizer=train_state.optimizer.replace(
target=flax.core.freeze(params)))
else:
return train_state.replace(params=flax.core.freeze(params))
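
# Illustrative usage sketch (assumptions: `train_state`, `restored_train_state`
# and the configs come from a Scenic training setup; only the
# `init_from.restore_next_sentence_prediction_head_params` flag is read by the
# function above):
#
#   config.init_from.restore_next_sentence_prediction_head_params = True
#   train_state = init_bert_from_train_state(
#       train_state, restored_train_state, config, restored_model_cfg)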