from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from jax.random import PRNGKey
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict, unfreeze

from transformers.modeling_flax_outputs import FlaxCausalLMOutputWithCrossAttentions
from transformers.file_utils import add_start_docstrings
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerationModule

from model.vae import VAE
from model.outputs import TransformerVaeOutput
from model.config import T5VaeConfig


@add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""")
class FlaxT5VaeForAutoencodingModule(nn.Module):
    config: T5VaeConfig
    dtype: jnp.dtype = jnp.float32

    def _get_encoder_module(self):
        return self.t5.encoder

    def _get_vae_encoder_module(self):
        return self.vae.encoder

    def _get_vae_decoder_module(self):
        return self.vae.decoder

    def _get_decoder_module(self):
        return self.t5.decoder

    def setup(self):
        self.t5 = FlaxT5ForConditionalGenerationModule(self.config.t5)
        self.vae = VAE(self.config)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        encoder_outputs=None,
        latent_codes=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        deterministic: bool = True,
    ):
        """
        Adapted from `FlaxT5ForConditionalGenerationModule`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode
        encoder_outputs = self.t5.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = encoder_outputs[0]

        # Autoencode: compress the encoder hidden states into latent codes, then expand them back
        hidden_states, latent_codes = self.vae(hidden_states, latent_codes)
        encoder_attention_mask = jnp.ones((hidden_states.shape[0], hidden_states.shape[1]))

        # Decode
        decoder_outputs = self.t5.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        sequence_output = decoder_outputs[0]

        if self.t5.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            sequence_output = sequence_output * (self.t5.config.d_model ** -0.5)

        if self.t5.config.tie_word_embeddings:
            shared_embedding = self.t5.shared.variables["params"]["embedding"]
            lm_logits = self.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
        else:
            lm_logits = self.t5.lm_head(sequence_output)

        if not return_dict:
            return (lm_logits, latent_codes) + decoder_outputs[1:] + encoder_outputs

        return TransformerVaeOutput(
            logits=lm_logits,
            latent_codes=latent_codes,
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5VaeConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: T5VaeConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        **kwargs
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")

        attention_mask = jnp.ones_like(input_ids)
        decoder_input_ids = jnp.ones_like(input_ids)
        decoder_attention_mask = jnp.ones_like(input_ids)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(
            rngs,
            input_ids,
            attention_mask,
            decoder_input_ids,
            decoder_attention_mask,
        )["params"]

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: jnp.ndarray = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if decoder_input_ids is None:
            raise ValueError(
                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
            )

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # prepare decoder inputs
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

    def init_cache(self, batch_size, max_length, latent_codes):
        r"""
        Args:
            batch_size (:obj:`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (:obj:`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
            latent_codes (:obj:`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
                ``latent_codes`` consists of compressed hidden-states at the output of the last layer of the encoder.
                Used in the cross-attention of the decoder.
        """
        # init input variables to retrieve cache
        decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
            vae_decoder_module = module._get_vae_decoder_module()
            decoder_module = module._get_decoder_module()
            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                encoder_hidden_states=vae_decoder_module(latent_codes),
                **kwargs,
            )

        init_variables = self.module.init(
            jax.random.PRNGKey(0),
            decoder_input_ids=decoder_input_ids,
            latent_codes=latent_codes,
            decoder_attention_mask=decoder_attention_mask,
            init_cache=True,
            method=_decoder_forward,
        )
        return unfreeze(init_variables["cache"])

    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        raise NotImplementedError()

    def decode(
        self,
        decoder_input_ids,
        latent_codes,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        raise NotImplementedError()


class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel):
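    """
    T5 model with a language-modeling head converted into a VAE: the T5 encoder's hidden states
    are compressed by the VAE into latent codes and expanded back into a sequence of hidden
    states that the T5 decoder cross-attends to.

    Example (a minimal sketch; the checkpoint path below is only a placeholder)::

        >>> from transformers import T5Tokenizer

        >>> model = FlaxT5VaeForAutoencoding.from_pretrained('path/to/t5-vae-checkpoint')
        >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')

        >>> inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors='jax')
        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * model.config.decoder_start_token_id

        >>> outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, decoder_input_ids=decoder_input_ids)
        >>> logits, latent_codes = outputs.logits, outputs.latent_codes
    """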
    module_class = FlaxT5VaeForAutoencodingModule

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        """
        Adapted from `FlaxT5PreTrainedModel`
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if decoder_input_ids is None:
            raise ValueError(
                "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed here."
            )

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # prepare decoder inputs
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)

        # handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
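        r"""
        Encode ``input_ids`` into latent codes: runs the T5 encoder and then the VAE encoder on
        its hidden states.

        Example (a minimal sketch mirroring :meth:`decode`; the checkpoint path below is only a
        placeholder)::

            >>> from transformers import T5Tokenizer

            >>> model = FlaxT5VaeForAutoencoding.from_pretrained('path/to/t5-vae-checkpoint')
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
            >>> latent_codes = model.encode(inputs.input_ids, attention_mask=inputs.attention_mask)
        """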
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, **kwargs):
            encode_module = module._get_encoder_module()
            vae_encoder_module = module._get_vae_encoder_module()
            return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0])

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

    def decode(
        self,
        decoder_input_ids,
        latent_codes,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example::

            >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small')
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')

            >>> text = "My friends are cool but they eat too many carbs."
            >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
            >>> latent_codes = model.encode(**inputs)

            >>> decoder_start_token_id = model.config.decoder_start_token_id
            >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

            >>> outputs = model.decode(decoder_input_ids, latent_codes)
            >>> logits = outputs.logits
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_attention_mask is None:
            batch_size, sequence_length = latent_codes.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache has already been initialized; mark it as
        # mutable so the decoder's attention modules can update it during this forward pass.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
            vae_decoder_module = module._get_vae_decoder_module()
            decoder_module = module._get_decoder_module()
            decoder_outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                encoder_hidden_states=vae_decoder_module(latent_codes),
                **kwargs,
            )
            sequence_output = decoder_outputs[0]

            if self.config.tie_word_embeddings:
                # Rescale output before projecting on vocab
                sequence_output = sequence_output * (self.config.d_model ** -0.5)

            if self.config.tie_word_embeddings:
                shared_embedding = module.t5.shared.variables["params"]["embedding"]
                lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
            else:
                lm_logits = module.t5.lm_head(sequence_output)

            return lm_logits, decoder_outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            latent_codes=latent_codes,
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        latent_codes=None,
        **kwargs
    ):
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, latent_codes)
        # Usually one would have to put 0's in the attention mask for positions that have not
        # been generated yet. Since the decoder uses a causal mask, those positions are masked
        # anyway, so a single static attention mask of length `max_length` can be used here,
        # which is more efficient to compile.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            extended_attention_mask = jax.lax.dynamic_update_slice(
                extended_attention_mask, decoder_attention_mask, (0, 0)
            )

        return {
            "past_key_values": past_key_values,
            "latent_codes": latent_codes,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs
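
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not exercised anywhere in this module): a
# teacher-forced autoencoding pass, where the decoder inputs are the input ids
# shifted right by one position starting from the decoder start token. The
# checkpoint path is a placeholder, and the manual shift below is only a
# sketch; training code may use its own shifting helper.
#
#   from transformers import T5Tokenizer
#
#   tokenizer = T5Tokenizer.from_pretrained("t5-small")
#   model = FlaxT5VaeForAutoencoding.from_pretrained("path/to/t5-vae-checkpoint")
#
#   inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="jax")
#   start = jnp.full((inputs.input_ids.shape[0], 1), model.config.decoder_start_token_id, dtype="i4")
#   decoder_input_ids = jnp.concatenate([start, inputs.input_ids[:, :-1]], axis=-1)
#
#   outputs = model(
#       input_ids=inputs.input_ids,
#       attention_mask=inputs.attention_mask,
#       decoder_input_ids=decoder_input_ids,
#   )
#   logits, latent_codes = outputs.logits, outputs.latent_codes
# ---------------------------------------------------------------------------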