from typing import Callable, List, Optional, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPooling,
    FlaxMaskedLMOutput,
    FlaxSequenceClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward
from .configuration_beit import BeitConfig


@flax.struct.dataclass
class FlaxBeitModelOutputWithPooling(FlaxBaseModelOutputWithPooling):
    """
    Class for outputs of [`FlaxBeitModel`].

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
            *config.use_mean_pooling* is set to True. If set to False, the final hidden state of the *[CLS]* token
            is returned instead.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
            the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """


BEIT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
    subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to
    general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BeitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""


BEIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`numpy.ndarray` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`AutoImageProcessor.__call__`] for details.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


def relative_position_index_init(window_size: Tuple[int, int]) -> jnp.ndarray:
    """
    Get pair-wise relative position index for each token inside the window.
    """
    # +3 accounts for the three special relations involving the [CLS] token:
    # cls-to-token, token-to-cls and cls-to-cls
    num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3

    coords_h = np.arange(window_size[0])
    coords_w = np.arange(window_size[1])
    coords = np.stack(np.meshgrid(coords_h, coords_w, indexing="ij"))  # 2, Wh, Ww
    coords_flatten = np.reshape(coords, (2, -1))  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = np.transpose(relative_coords, (1, 2, 0))  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1

    relative_position_index = np.zeros(shape=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
    relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    relative_position_index[0, 0:] = num_relative_distance - 3
    relative_position_index[0:, 0] = num_relative_distance - 2
    relative_position_index[0, 0] = num_relative_distance - 1
    return jnp.array(relative_position_index)
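
# A minimal sketch of what the helper produces, assuming a tiny 2x2 window: the index
# matrix covers the [CLS] token plus the 4 patch tokens, so its shape is (5, 5), and
# every entry indexes into a bias table with (2*2 - 1) * (2*2 - 1) + 3 = 12 rows.
#   idx = relative_position_index_init((2, 2))
#   idx.shape       # (5, 5)
#   int(idx.max())  # 11, the cls-to-cls entry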


def ones_with_scale(key, shape, scale, dtype=jnp.float32):
    # Initializer with the (key, shape, ...) signature expected by `self.param`;
    # the RNG key is unused because the init is deterministic.
    return jnp.ones(shape, dtype) * scale
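
# Usage sketch (mirrors how the layer-scale parameters are created below):
#   lambda_1 = self.param("lambda_1", ones_with_scale, (hidden_size,), init_value)
# yields an array of shape (hidden_size,) filled with `init_value`.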


class FlaxBeitDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    rate: float

    @nn.module.compact
    def __call__(self, inputs, deterministic: Optional[bool] = True):
        if self.rate == 0.0:
            return inputs

        keep_prob = 1.0 - self.rate
        if deterministic:
            return inputs
        else:
            shape = (inputs.shape[0],) + (1,) * (inputs.ndim - 1)  # one mask entry per sample
            rng = self.make_rng("droppath")
            random_tensor = keep_prob + jax.random.uniform(rng, shape=shape, dtype=inputs.dtype)
            binary_tensor = jnp.floor(random_tensor)  # 1 with probability keep_prob, else 0
            output = inputs / keep_prob * binary_tensor  # rescale to keep the expectation unchanged
            return output
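
# A minimal usage sketch: because the module calls `self.make_rng("droppath")`,
# stochastic application requires a "droppath" RNG to be passed to `apply`:
#   drop_path = FlaxBeitDropPath(rate=0.1)
#   variables = drop_path.init(jax.random.PRNGKey(0), x)  # deterministic by default
#   y = drop_path.apply(variables, x, deterministic=False, rngs={"droppath": jax.random.PRNGKey(1)})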


class FlaxBeitPatchEmbeddings(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.num_channels = self.config.num_channels
        image_size = self.config.image_size
        patch_size = self.config.patch_size
        num_patches = (image_size // patch_size) * (image_size // patch_size)
        patch_shape = (image_size // patch_size, image_size // patch_size)
        self.num_patches = num_patches
        self.patch_shape = patch_shape
        self.projection = nn.Conv(
            self.config.hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

    def __call__(self, pixel_values):
        num_channels = pixel_values.shape[-1]
        if num_channels != self.num_channels:
            raise ValueError(
                f"Make sure that the channel dimension of the pixel values ({num_channels}) matches the one set in"
                f" the configuration ({self.num_channels})."
            )
        embeddings = self.projection(pixel_values)
        batch_size, _, _, channels = embeddings.shape
        return jnp.reshape(embeddings, (batch_size, -1, channels))
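
# Shape walkthrough (sketch), assuming the default base config (image_size=224,
# patch_size=16, hidden_size=768): a (1, 224, 224, 3) NHWC input is projected by the
# strided convolution to (1, 14, 14, 768), then flattened to (1, 196, 768).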


class FlaxBeitEmbeddings(nn.Module):
    """Construct the [CLS] token, position and patch embeddings."""

    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
        if self.config.use_mask_token:
            self.mask_token = self.param("mask_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
        self.patch_embeddings = FlaxBeitPatchEmbeddings(self.config, dtype=self.dtype)
        num_patches = self.patch_embeddings.num_patches
        if self.config.use_absolute_position_embeddings:
            self.position_embeddings = self.param(
                "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
            )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, pixel_values, bool_masked_pos=None, deterministic=True):
        embeddings = self.patch_embeddings(pixel_values)
        batch_size, seq_len, _ = embeddings.shape

        cls_tokens = jnp.broadcast_to(self.cls_token, (batch_size, 1, self.config.hidden_size))
        cls_tokens = cls_tokens.astype(embeddings.dtype)

        if bool_masked_pos is not None:
            mask_tokens = jnp.broadcast_to(self.mask_token, (batch_size, seq_len, self.config.hidden_size))
            mask_tokens = mask_tokens.astype(embeddings.dtype)
            # replace the embeddings of the masked positions by the mask token
            w = jnp.expand_dims(bool_masked_pos, axis=-1)
            embeddings = embeddings * (1 - w) + mask_tokens * w

        embeddings = jnp.concatenate((cls_tokens, embeddings), axis=1)

        if self.config.use_absolute_position_embeddings:
            embeddings = embeddings + self.position_embeddings.astype(embeddings.dtype)

        embeddings = self.dropout(embeddings, deterministic=deterministic)
        return embeddings
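
# Masking sketch for pre-training: `bool_masked_pos` is a (batch_size, num_patches)
# array of 0/1 flags (this path requires `config.use_mask_token=True`). For instance,
# masking the first 75 of 196 patches:
#   bool_masked_pos = jnp.zeros((1, 196), dtype=jnp.int32).at[0, :75].set(1)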


class FlaxBeitRelativePositionBias(nn.Module):
    config: BeitConfig
    window_size: Tuple[int, int]
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # the 3 extra rows cover the cls-to-token, token-to-cls and cls-to-cls relations
        num_relative_distance = (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) + 3
        self.relative_position_bias_table = self.param(
            "relative_position_bias_table",
            nn.initializers.zeros,
            (num_relative_distance, self.config.num_attention_heads),
        )  # 2*Wh-1 * 2*Ww-1 + 3, nH

        self.relative_position_index = relative_position_index_init(self.window_size)

    def __call__(self):
        index = self.relative_position_index.reshape(-1)
        shape = (self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1)
        relative_position_bias = self.relative_position_bias_table[index].reshape(shape)  # Wh*Ww+1, Wh*Ww+1, nH
        return jnp.transpose(relative_position_bias, (2, 0, 1))  # nH, Wh*Ww+1, Wh*Ww+1
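
# Shape sketch, assuming the base 224/16 configuration (window_size=(14, 14), 12
# attention heads): the bias table has (2*14 - 1)**2 + 3 = 732 rows, and `__call__`
# returns a (12, 197, 197) bias that is broadcast over the batch dimension and added
# to the attention logits.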


class FlaxBeitSelfAttention(nn.Module):
    config: BeitConfig
    window_size: Tuple[int, int]
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0 and not hasattr(
            self.config, "embedding_size"
        ):
            raise ValueError(
                f"The hidden size {self.config.hidden_size} is not a multiple of the number of attention "
                f"heads {self.config.num_attention_heads}."
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            use_bias=False,
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.relative_position_bias = (
            FlaxBeitRelativePositionBias(self.config, window_size=self.window_size, dtype=self.dtype)
            if self.window_size
            else None
        )

    def __call__(
        self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
    ):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attention_bias = jnp.array(0.0, dtype=self.dtype)
        # Add relative position bias if present (per-layer bias)
        if self.relative_position_bias is not None:
            attention_bias = jnp.expand_dims(self.relative_position_bias(), 0)
            attention_bias = attention_bias.astype(query_states.dtype)

        # Add shared relative position bias if provided by the encoder
        if relative_position_bias is not None:
            attention_bias = attention_bias + relative_position_bias.astype(attention_bias.dtype)

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Contract the attention weights (..., heads, q_len, k_len) with the value
        # states (..., k_len, heads, head_dim) to get (..., q_len, heads, head_dim),
        # then merge the heads back into the hidden dimension.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs


class FlaxBeitSelfOutput(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxBeitAttention(nn.Module):
    config: BeitConfig
    window_size: Tuple[int, int]
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxBeitSelfAttention(self.config, self.window_size, dtype=self.dtype)
        self.output = FlaxBeitSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self, hidden_states, relative_position_bias=None, deterministic=True, output_attentions: bool = False
    ):
        attn_outputs = self.attention(
            hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
        )
        attn_output = attn_outputs[0]
        attn_output = self.output(attn_output, deterministic=deterministic)

        outputs = (attn_output,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs


class FlaxBeitIntermediate(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxBeitOutput(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxBeitLayer(nn.Module):
    config: BeitConfig
    window_size: Tuple[int, int]
    drop_path_rate: float
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype)
        self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxBeitOutput(self.config, dtype=self.dtype)
        self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate)
        self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

        self.init_values = self.config.layer_scale_init_value
        if self.init_values > 0:
            self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size,), self.init_values)
            self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size,), self.init_values)
        else:
            self.lambda_1 = None
            self.lambda_2 = None

    def __call__(
        self, hidden_states, relative_position_bias=None, deterministic: bool = True, output_attentions: bool = False
    ):
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in BEiT, layernorm is applied before self-attention
            relative_position_bias,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]

        # apply lambda_1 if present (layer scale)
        if self.lambda_1 is not None:
            attention_output = self.lambda_1.astype(attention_output.dtype) * attention_output

        # first residual connection
        hidden_states = self.drop_path(attention_output, deterministic=deterministic) + hidden_states

        # in BEiT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)

        layer_output = self.intermediate(layer_output)
        layer_output = self.output(layer_output, deterministic=deterministic)

        # apply lambda_2 if present (layer scale)
        if self.lambda_2 is not None:
            layer_output = self.lambda_2.astype(layer_output.dtype) * layer_output

        # second residual connection
        layer_output = self.drop_path(layer_output, deterministic=deterministic) + hidden_states

        outputs = (layer_output,)

        if output_attentions:
            outputs += (self_attention_outputs[1],)

        return outputs


class FlaxBeitLayerCollection(nn.Module):
    config: BeitConfig
    window_size: Tuple[int, int]
    drop_path_rates: List[float]
    relative_position_bias: Callable[[], jnp.ndarray]
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.layers = [
            FlaxBeitLayer(
                self.config,
                window_size=self.window_size if self.config.use_relative_position_bias else None,
                drop_path_rate=self.drop_path_rates[i],
                name=str(i),
                dtype=self.dtype,
            )
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # the shared relative position bias (if any) is recomputed and passed to every layer
            relative_position_bias = self.relative_position_bias() if self.relative_position_bias is not None else None
            layer_outputs = layer(
                hidden_states, relative_position_bias, deterministic=deterministic, output_attentions=output_attentions
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in (hidden_states, all_hidden_states, all_attentions) if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )


class FlaxBeitEncoder(nn.Module):
    config: BeitConfig
    window_size: Tuple[int, int]
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.use_shared_relative_position_bias:
            self.relative_position_bias = FlaxBeitRelativePositionBias(
                config=self.config, window_size=self.window_size, dtype=self.dtype
            )

        # stochastic depth decay rule
        drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers))
        self.layer = FlaxBeitLayerCollection(
            self.config,
            window_size=self.window_size,
            drop_path_rates=drop_path_rates,
            relative_position_bias=(
                self.relative_position_bias if self.config.use_shared_relative_position_bias else None
            ),
            dtype=self.dtype,
        )

    def __call__(
        self,
        hidden_states,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxBeitPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BeitConfig
    base_model_prefix = "beit"
    main_input_name = "pixel_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BeitConfig,
        input_shape=None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        if input_shape is None:
            input_shape = (1, config.image_size, config.image_size, config.num_channels)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=self.dtype)

        params_rng, dropout_rng = jax.random.split(rng)
        dropout_rng, droppath_rng = jax.random.split(dropout_rng)
        rngs = {"params": params_rng, "dropout": dropout_rng, "droppath": droppath_rng}

        random_params = self.module.init(rngs, pixel_values, return_dict=False)["params"]

        if params is not None:
            # keep the provided parameters and fill in any missing keys with freshly initialized ones
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        pixel_values,
        bool_masked_pos=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # the module expects NHWC inputs, while the public API takes NCHW pixel values
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            dropout_rng, droppath_rng = jax.random.split(dropout_rng)
            rngs["dropout"] = dropout_rng
            rngs["droppath"] = droppath_rng

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixel_values, dtype=jnp.float32),
            bool_masked_pos,
            not train,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
        )
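

# End-to-end sketch, assuming the base 224/16 checkpoint used in the docstrings
# below: pixel values follow the PyTorch NCHW layout and are transposed internally.
#   model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
#   pixel_values = jnp.zeros((1, 3, 224, 224), dtype=jnp.float32)
#   outputs = model(pixel_values)
#   outputs.last_hidden_state.shape  # (1, 197, 768): 196 patch tokens + 1 [CLS] token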


class FlaxBeitPooler(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.use_mean_pooling:
            self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states):
        if self.config.use_mean_pooling:
            # Mean pool the final hidden states of the patch tokens
            patch_tokens = hidden_states[:, 1:, :]
            pooled_output = self.layernorm(jnp.mean(patch_tokens, axis=1))
        else:
            # Pool by simply taking the final hidden state of the [CLS] token
            pooled_output = hidden_states[:, 0]

        return pooled_output
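
# Sketch: both pooling strategies map (batch, seq_len, hidden) to (batch, hidden).
# Assuming a `BeitConfig` instance named `config`:
#   pooler = FlaxBeitPooler(config)
#   variables = pooler.init(jax.random.PRNGKey(0), jnp.ones((2, 197, config.hidden_size)))
#   pooled = pooler.apply(variables, hidden_states)  # (2, hidden_size)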


class FlaxBeitModule(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True

    def setup(self):
        self.embeddings = FlaxBeitEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBeitEncoder(
            self.config, window_size=self.embeddings.patch_embeddings.patch_shape, dtype=self.dtype
        )
        if not self.config.use_mean_pooling:
            self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.pooler = FlaxBeitPooler(self.config, dtype=self.dtype) if self.add_pooling_layer else None

    def __call__(
        self,
        pixel_values,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        hidden_states = self.embeddings(pixel_values, bool_masked_pos, deterministic=deterministic)

        outputs = self.encoder(
            hidden_states,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        if not self.config.use_mean_pooling:
            hidden_states = self.layernorm(hidden_states)
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if no pooling layer is used, there is no pooled output to return
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBeitModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
    BEIT_START_DOCSTRING,
)
class FlaxBeitModel(FlaxBeitPreTrainedModel):
    module_class = FlaxBeitModule


FLAX_BEIT_MODEL_DOCSTRING = """
    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitModel
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
    >>> model = FlaxBeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> last_hidden_states = outputs.last_hidden_state
    ```
"""

overwrite_call_docstring(FlaxBeitModel, FLAX_BEIT_MODEL_DOCSTRING)
append_replace_return_docstrings(FlaxBeitModel, output_type=FlaxBeitModelOutputWithPooling, config_class=BeitConfig)


class FlaxBeitForMaskedImageModelingModule(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype)

        # Classifier head
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        pixel_values=None,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.beit(
            pixel_values,
            bool_masked_pos,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.layernorm(sequence_output)
        # predict visual token ids for the patch positions ([CLS] is skipped)
        prediction_scores = self.lm_head(sequence_output[:, 1:])

        if not return_dict:
            # no pooling layer is used, so the remaining outputs start at index 1
            output = (prediction_scores,) + outputs[1:]
            return output

        return FlaxMaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
    BEIT_START_DOCSTRING,
)
class FlaxBeitForMaskedImageModeling(FlaxBeitPreTrainedModel):
    module_class = FlaxBeitForMaskedImageModelingModule


FLAX_BEIT_MLM_DOCSTRING = """
    bool_masked_pos (`numpy.ndarray` of shape `(batch_size, num_patches)`):
        Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

    Returns:

    Examples:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitForMaskedImageModeling
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224-pt22k")
    >>> model = FlaxBeitForMaskedImageModeling.from_pretrained("microsoft/beit-base-patch16-224-pt22k")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
"""

overwrite_call_docstring(FlaxBeitForMaskedImageModeling, FLAX_BEIT_MLM_DOCSTRING)
append_replace_return_docstrings(
    FlaxBeitForMaskedImageModeling, output_type=FlaxMaskedLMOutput, config_class=BeitConfig
)


class FlaxBeitForImageClassificationModule(nn.Module):
    config: BeitConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.beit = FlaxBeitModule(config=self.config, dtype=self.dtype, add_pooling_layer=True)
        self.classifier = nn.Dense(
            self.config.num_labels,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        pixel_values=None,
        bool_masked_pos=None,
        deterministic: bool = True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.beit(
            pixel_values,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        logits = self.classifier(pooled_output)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return output

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
    hidden states of the patch tokens) e.g. for ImageNet.
    """,
    BEIT_START_DOCSTRING,
)
class FlaxBeitForImageClassification(FlaxBeitPreTrainedModel):
    module_class = FlaxBeitForImageClassificationModule


FLAX_BEIT_CLASSIF_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoImageProcessor, FlaxBeitForImageClassification
    >>> from PIL import Image
    >>> import requests

    >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    >>> image = Image.open(requests.get(url, stream=True).raw)

    >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")
    >>> model = FlaxBeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")

    >>> inputs = image_processor(images=image, return_tensors="np")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    >>> # model predicts one of the 1000 ImageNet classes
    >>> predicted_class_idx = logits.argmax(-1).item()
    >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
    ```
"""

overwrite_call_docstring(FlaxBeitForImageClassification, FLAX_BEIT_CLASSIF_DOCSTRING)
append_replace_return_docstrings(
    FlaxBeitForImageClassification, output_type=FlaxSequenceClassifierOutput, config_class=BeitConfig
)