import jax
import jax.numpy as jnp

import flax
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, unfreeze

from typing import Any, Optional, Tuple

from transformers import GPT2Config, GPT2Tokenizer
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2BlockCollection, FlaxGPT2Module

tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token="<|endoftext|>")
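
# Illustrative tokenization sketch (not part of the original script; the prompt and choice
# strings below are placeholders). Each example is encoded as four (prompt, choice) pairs,
# so a batch stacks to shape (batch_size, 4, seq_len), which the multiple-choice module
# defined further down flattens to (4 * batch_size, seq_len).
example_encoding = tokenizer(
    ["Where does the sun rise?"] * 4,
    ["In the east.", "In the west.", "At noon.", "Underground."],
    padding="max_length",
    truncation=True,
    max_length=32,
    return_tensors="np",
)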

GPT2_START_DOCSTRING = r"""
    This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the
    generic methods the library implements for all its models (such as downloading or saving, resizing the input
    embeddings, pruning heads etc.).
    This model is also a Flax Linen `flax.nn.Module
    <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax
    Module and refer to the Flax documentation for all matters related to general usage and behavior.
    Finally, this model supports inherent JAX features such as:
    - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__
    - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__
    - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__
    - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__

    Parameters:
        config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the
            model weights.
"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, input_ids_length)`):
            :obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.
            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            `What are attention masks? <../glossary.html#attention-mask>`__
        position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
            config.max_position_embeddings - 1]``.
        past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):
            Dictionary of pre-computed hidden states (keys and values in the attention blocks) that can be used for
            fast auto-regressive decoding. Pre-computed key and value hidden states are of shape `[batch_size, max_length]`.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attention tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
"""


class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPT2Config
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: GPT2Config,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (:obj:`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (:obj:`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return init_variables["cache"]

    @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG that dropout may need.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache must be marked as mutable so the
        # attention layers can update it during the forward pass.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # Add the updated cache to the model output.
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs


class FlaxGPT2ForMultipleChoiceModule(nn.Module):
    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=0.2)
        # One logit per answer choice; this head assumes exactly 4 choices per example.
        self.classifier = nn.Dense(4, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Inputs arrive as (batch_size, 4, seq_len) (or any shape reshapeable to it);
        # flatten the choice dimension so the backbone sees 4 * batch_size sequences.
        batch_size = input_ids.shape[0]
        input_ids = input_ids.reshape(4 * batch_size, -1)
        position_ids = position_ids.reshape(4 * batch_size, -1)
        attention_mask = attention_mask.reshape(4 * batch_size, -1)

        outputs = self.transformer(
            input_ids,
            attention_mask,
            position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Mean-pool the token representations of each sequence, then concatenate the four
        # pooled vectors of an example before classification.
        hidden_states = outputs[0]
        hidden_states = jnp.mean(hidden_states, axis=1)
        hidden_states = hidden_states.reshape(batch_size, -1)

        # Dropout draws its PRNG from the "dropout" rng collection supplied by the caller.
        dropout_output = self.dropout(hidden_states, deterministic=deterministic)

        logits = self.classifier(dropout_output)
        reshaped_logits = logits.reshape(-1, 4)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]
        return reshaped_logits


class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):
    module_class = FlaxGPT2ForMultipleChoiceModule
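

# Minimal usage sketch (illustrative only, not part of the original script): the model is
# built from a fresh GPT2Config with randomly initialized weights, and the batch size,
# sequence length, and all-ones dummy inputs are assumptions chosen just to show shapes.
# input_shape=(1, 4) is used so that weight initialization can reshape to (4, seq_len).
if __name__ == "__main__":
    config = GPT2Config()
    model = FlaxGPT2ForMultipleChoice(config, input_shape=(1, 4), seed=0)

    batch_size, seq_len = 2, 16
    dummy_input_ids = jnp.ones((batch_size, 4, seq_len), dtype="i4")
    dummy_attention_mask = jnp.ones_like(dummy_input_ids)

    logits = model(dummy_input_ids, attention_mask=dummy_attention_mask)
    print(logits.shape)  # expected: (batch_size, 4) -- one logit per answer choice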