# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for models working with bert."""
from typing import Callable, Dict, Optional, Tuple, Union
from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
from scenic.model_lib.base_models import base_model
from scenic.model_lib.base_models import model_utils
# Aliases for custom types:
Batch = Dict[str, jnp.ndarray]
MetricFn = Callable[[Dict[str, jnp.ndarray], Batch], Dict[str, Tuple[float,
int]]]
LossFn = Callable[[Dict[str, jnp.ndarray], Batch, Optional[jnp.ndarray]], float]
def num_examples(
logits: jnp.ndarray,
weights: Optional[jnp.ndarray] = None) -> Union[jnp.ndarray, int]:
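  """Returns the number of examples in the batch.

  If `weights` (e.g. the batch mask) is given, padded examples count with
  weight 0; otherwise the leading (batch) dimension of `logits` is used.
  """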
if weights is None:
return logits.shape[0]
return weights.sum()
def sparse_weighted_unnormalized_softmax_cross_entropy(
logits: jnp.ndarray,
labels: jnp.ndarray,
mlm_weights: jnp.ndarray,
batch_mask_weights: Optional[jnp.ndarray] = None) -> jnp.ndarray:
"""Computes sparse weighted softmax cross entropy give logits and targets.
Args:
logits: Logits of shape [batch_size, length, vocab_size].
labels: Labels from {0 ... vocab_size - 1} of shape [batch_size, length].
mlm_weights: Weights of shape [batch_size, length], indicating masked tokens
in masked language modeling task.
batch_mask_weights: None or array of shape [batch,] indicating masked
examples.
Returns:
Per example Loss value.
"""
batch_size, length, vocab_size = logits.shape
logits = jax.nn.log_softmax(logits)
logits, mlm_weights = logits.ravel(), mlm_weights.ravel()
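  # Turn the [batch, length] labels into indices into the flattened
  # [batch * length * vocab_size] log-probabilities: the label of token (b, l)
  # lives at position (b * length + l) * vocab_size + labels[b, l].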
offsets = (np.arange(batch_size * length) * vocab_size).reshape((-1, length))
labels = (labels + offsets).ravel()
loss = -jnp.take(logits, labels, axis=0)
  loss = (loss * mlm_weights).reshape((batch_size, length))
  mlm_weights = mlm_weights.reshape((batch_size, length))
  # Mean loss over the masked tokens of each example; shape [batch_size].
  loss = loss.sum(axis=-1) / (mlm_weights.sum(axis=-1) + 1e-8)
if batch_mask_weights is not None:
loss = model_utils.apply_weights(loss, batch_mask_weights)
return loss
def sparse_weighted_softmax_cross_entropy(
logits: jnp.ndarray,
labels: jnp.ndarray,
mlm_weights: jnp.ndarray,
batch_mask_weights: Optional[jnp.ndarray] = None) -> jnp.ndarray:
"""Same as weighted_unnormalized, but additionally takes a mean.
Args:
logits: Logits of shape [batch_size, length, vocab_size].
labels: Labels from {0 ... vocab_size - 1} of shape [batch_size, length].
mlm_weights: Weights of shape [batch_size, length], indicating masked tokens
in masked language modeling task.
batch_mask_weights: None or array of shape [batch,] indicating masked
examples.
Returns:
The mean cross entropy of the examples in the given batch as a scalar.
"""
if batch_mask_weights is not None:
normalization = batch_mask_weights.sum()
else:
normalization = mlm_weights.shape[0] # Batch size.
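  # The per-example losses below are already means over each example's masked
  # tokens, so dividing their sum by the (unpadded) batch size gives the mean
  # loss per example.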
sparse_unnormalized_softmax_ce = (
sparse_weighted_unnormalized_softmax_cross_entropy( # pylint: disable=line-too-long
logits, labels, mlm_weights, batch_mask_weights
)
)
return jnp.sum(sparse_unnormalized_softmax_ce) / (normalization + 1e-8)
def sparse_weighted_per_example_accuracy(
logits: jnp.ndarray,
labels: jnp.ndarray,
mlm_weights: jnp.ndarray,
batch_mask_weights: Optional[jnp.ndarray] = None) -> jnp.ndarray:
"""Computes weighted number of correctly classified over the given batch.
This computes the weighted number of correctly classified masked tokens in a
single, potentially padded minibatch. If the minibatch/inputs is padded (i.e.,
it contains null examples/pad pixels) it is assumed that batch_mask_weights
is a binary mask where 0 indicates that the example/pixel is null/padded.
We assume the trainer will aggregate and divide by number of samples.
Args:
logits: Logits of shape [batch_size, length, vocab_size].
labels: Labels from {0 ... vocab_size - 1} of shape [batch_size, length].
mlm_weights: Weights of shape [batch_size, length], indicating masked tokens
in masked language modeling task.
batch_mask_weights: None or array of shape [batch,] indicating masked
examples.
Returns:
Per example accuracy of predicted masked tokens.
"""
preds = jnp.argmax(logits, axis=-1)
correct = jnp.equal(preds, labels)
correct = correct * mlm_weights
  # Shape of per-example accuracy will be (batch_size,).
per_ex_accuracy = correct.sum(axis=-1) / (mlm_weights.sum(axis=-1) + 1e-8)
if batch_mask_weights is not None:
per_ex_accuracy = model_utils.apply_weights(per_ex_accuracy,
batch_mask_weights)
return per_ex_accuracy
def bert_metrics_function(outputs: Dict[str, jnp.ndarray],
batch: Batch) -> Dict[str, Tuple[float, int]]:
"""Calcualte metrics for the BERT task.
Args:
outputs: Output of model that has masked LM logits of shape [batch, length,
vocab_size], and next sentence prediction logits of shape [batch, 2].
batch: Batch of data that has 'masked_lm_ids', 'masked_lm_weights' and
'next_sentence_labels'.
Returns:
A dict of metrics, in which keys are metrics name and values are tuples of
(metric, normalizer).
"""
mlm_logits = outputs['mlm_logits']
nsp_logits = outputs['nsp_logits']
next_sentence_labels = common_utils.onehot(batch['next_sentence_labels'], 2)
batch_weights = batch.get('batch_mask') # batch_mask might not be defined
per_ex_nsp_loss = model_utils.weighted_unnormalized_softmax_cross_entropy(
nsp_logits, next_sentence_labels, batch_weights)
per_ex_nsp_accuracy = model_utils.weighted_correctly_classified(
nsp_logits, next_sentence_labels, batch_weights)
per_ex_mlm_loss = sparse_weighted_unnormalized_softmax_cross_entropy(
mlm_logits, batch['masked_lm_ids'], batch['masked_lm_weights'],
batch_weights)
per_ex_mlm_accuracy = sparse_weighted_per_example_accuracy(
mlm_logits, batch['masked_lm_ids'], batch['masked_lm_weights'],
batch_weights)
# This psum is required to correctly evaluate with multihost. Only host 0
# will report the metrics, so we must aggregate across all hosts. The psum
# will map an array of shape [n_global_devices, batch_size] -> [batch_size]
# by summing across the devices dimension. The outer sum then sums across the
  # batch dim. The net result is a sum across all samples in the sharded
  # batch.
evaluated_metrics = {}
normalizer = num_examples(mlm_logits, batch_weights)
for name, value in zip(
['nsp_loss', 'nsp_accuracy', 'mlm_loss', 'mlm_accuracy', 'loss'], [
per_ex_nsp_loss, per_ex_nsp_accuracy, per_ex_mlm_loss,
per_ex_mlm_accuracy, per_ex_nsp_loss + per_ex_mlm_loss
]):
evaluated_metrics[name] = model_utils.psum_metric_normalizer(
(value, normalizer))
return evaluated_metrics # pytype: disable=bad-return-type # jax-ndarray
def compute_bert_loss(mlm_logits: jnp.ndarray, nsp_logits: jnp.ndarray,
batch: Batch) -> float:
"""Computes BERT loss.
Args:
mlm_logits: Masked LM logits of shape [batch, length, vocab_size].
nsp_logits: Next sentence prediction logits of shape [batch, 2].
batch: Batch of data that has 'masked_lm_ids', 'masked_lm_weights' and
'next_sentence_labels'.
Returns:
Loss value.
"""
next_sentence_labels = common_utils.onehot(batch['next_sentence_labels'], 2)
batch_weights = batch.get('batch_mask') # batch_mask might not be defined
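  # Both loss terms below are already normalized batch means, so the total
  # BERT loss is simply their sum.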
nsp_loss = model_utils.weighted_softmax_cross_entropy(nsp_logits,
next_sentence_labels,
batch_weights)
mlm_loss = sparse_weighted_softmax_cross_entropy(mlm_logits,
batch['masked_lm_ids'],
batch['masked_lm_weights'],
batch_weights)
return nsp_loss + mlm_loss # pytype: disable=bad-return-type # jax-ndarray
class BERTBaseModel(base_model.BaseModel):
"""Defines BERT base models.
A model is class with three members: get_metrics_fn, loss_fn, and a
flax_model.
get_metrics_fn returns a callable function, metric_fn, that calculates the
metrics and returns a dictionary. The metric function computes f(x_i, y_i) on
a minibatch, it has API:
```metric_fn(logits, label, weights).```
The trainer will then aggregate and compute the mean across all samples
evaluated.
loss_fn is a function of API
loss = loss_fn(logits, batch, model_params=None).
This model class defines a softmax_cross_entropy_loss with weight decay,
where the weight decay factor is determined by config.l2_decay_factor.
flax_model is returned from the build_flax_model function. A typical
usage pattern will be:
```
model_cls = bert_model.BERTModel
model = model_cls(config, dataset.meta_data)
flax_model = model.build_flax_model
dummy_input = {name: jnp.zeros(input_shape, model_input_dtype), ...}
model_state, params = flax_model.init(
rng, dummy_input, train=False).pop('params')
```
And this is how to call the model:s
```
variables = {'params': params, **model_state}
output, new_model_state = flax_model.apply(variables, inputs, ...)
```
"""
def get_metrics_fn(self, split: Optional[str] = None) -> MetricFn: # pytype: disable=signature-mismatch # jax-ndarray
"""Returns a callable metric function for the model.
Args:
split: The split for which we calculate the metrics. It should be one of
the ['train', 'validation', 'test'].
Returns: A metric function with the following API: ```metrics_fn(outputs,
batch)```
"""
del split # For all splits, we return the same metric functions.
return bert_metrics_function
def loss_function(self,
outputs: Dict[str, jnp.ndarray],
batch: Batch,
model_params: Optional[jnp.ndarray] = None) -> float:
"""Returns softmax cross entropy loss with an L2 penalty on the weights.
Args:
outputs: a dictionary containing either 'logits' key of shape [batch,
length, num_classes] or 'nsp_logits' of shape [batch, 2] and
'mlm_logits' of shape [batch, length, vocab_size] (for 'BERT' task).
batch: Batch of data that has 'label' and optionally 'batch_mask'.
model_params: Parameters of the model, for optionally applying
regularization.
Returns:
Total loss.
"""
total_loss = compute_bert_loss(outputs['mlm_logits'], outputs['nsp_logits'],
batch)
if self.config.get('l2_decay_factor'):
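      # Optional weight decay, applied as an explicit L2 penalty over all
      # model parameters: 0.5 * l2_decay_factor * sum of squared weights.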
l2_loss = model_utils.l2_regularization(model_params)
total_loss += 0.5 * self.config.l2_decay_factor * l2_loss
return total_loss
def build_flax_model(self):
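    """Builds the flax model; must be implemented by subclasses."""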
raise NotImplementedError('Subclasses must implement build_flax_model().')
def default_flax_model_config(self):
"""Default config for the flax model that is built in `build_flax_model`.
This function in particular serves the testing functions and supposed to
provide config tha are passed to the flax_model when it's build in
`build_flax_model` function, e.g., `model_dtype_str`.
"""
raise NotImplementedError(
'Subclasses must implement default_flax_model_config().')