| import tempfile |
| import unittest |
| import random |
|
|
| import numpy as np |
| import optax |
| from flax.training.common_utils import onehot |
|
|
| from transformers import is_flax_available, AutoTokenizer |
| from transformers.models.t5.modeling_flax_t5 import shift_tokens_right |
| from transformers.testing_utils import require_flax, require_tokenizers, slow |
|
|
| from tests.test_configuration_common import ConfigTester |
| from tests.test_generation_flax_utils import FlaxGenerationTesterMixin |
| from tests.test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor |
|
|
|
|
| if is_flax_available(): |
| import os |
|
|
| |
| |
| |
| os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" |
|
|
| import jax |
| import jax.numpy as jnp |
| from flax.core.frozen_dict import unfreeze |
| from flax.traverse_util import flatten_dict |
| from transformers import FLAX_MODEL_MAPPING |
| from model.t5_vae import FlaxT5VaeForAutoencoding, T5VaeConfig |
|
|
|
|
class FlaxVaeModelTester:
    """Helper that builds a small ``T5VaeConfig`` with random inputs and runs model checks.

    The constructor only records hyper-parameters; the config and the input
    tensors are created on demand by :meth:`prepare_config_and_inputs`.
    Assertions are reported through ``parent`` (the owning ``unittest.TestCase``).
    """

    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        seq_length=7,
        latent_token_size=10,
        n_latent_tokens=3,
        is_training=True,
        use_attention_mask=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        d_ff=37,
        relative_attention_num_buckets=8,
        dropout_rate=0.1,
        initializer_factor=0.002,
        eos_token_id=1,
        pad_token_id=0,
        decoder_start_token_id=0,
        scope=None,
        decoder_layers=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.latent_token_size = latent_token_size
        self.n_latent_tokens = n_latent_tokens

        self.seq_length = seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.d_ff = d_ff
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.initializer_factor = initializer_factor
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        # Store the caller's value (this used to be hard-coded to None).
        self.scope = scope
        self.decoder_layers = decoder_layers

    def prepare_config_and_inputs(self):
        """Create a config and random encoder/decoder inputs.

        Returns:
            A tuple ``(config, input_ids, decoder_input_ids, attention_mask,
            decoder_attention_mask)``. Both masks are ``None`` when
            ``use_attention_mask`` is false.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        # Third argument is the decoder start token, not the pad token.
        decoder_input_ids = shift_tokens_right(input_ids, self.pad_token_id, self.decoder_start_token_id)

        attention_mask = None
        decoder_attention_mask = None
        if self.use_attention_mask:
            attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
            decoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        config = T5VaeConfig(
            latent_token_size=self.latent_token_size,
            n_latent_tokens=self.n_latent_tokens,
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            block_size=self.seq_length,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_decoder_layers=self.decoder_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
        )

        return (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
    ):
        """Run the model with and without attention masks and check the output shapes."""
        model = FlaxT5VaeForAutoencoding(config=config)
        # Forward pass with both masks; only checked for running without error.
        model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
        # NOTE(review): assumes the output object exposes `last_hidden_state` and
        # `encoder_last_hidden_state` -- confirm against the model's output class.
        decoder_output = result.last_hidden_state
        encoder_output = result.encoder_last_hidden_state

        expected_shape = (self.batch_size, self.seq_length, self.hidden_size)
        self.parent.assertEqual(encoder_output.shape, expected_shape)
        self.parent.assertEqual(decoder_output.shape, expected_shape)

    def check_use_cache_forward_with_attn_mask(
        self,
        model_class_name,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
    ):
        """Check that cached step-wise decoding matches a single full decode call.

        The last-position outputs of the cached path and the uncached path are
        compared on their first five features with a tolerance of ``1e-3``.
        """
        max_decoder_length = 20
        model = model_class_name(config)

        latent_codes = model.encode(input_ids)

        # Replace the random mask with all ones so no row is fully masked out.
        decoder_attention_mask = jnp.ones_like(decoder_attention_mask)

        # Pad the mask with zeros up to the cache length.
        batch_size, mask_length = decoder_attention_mask.shape
        decoder_attention_mask_cache = jnp.concatenate(
            [
                decoder_attention_mask,
                jnp.zeros((batch_size, max_decoder_length - mask_length)),
            ],
            axis=-1,
        )

        # The cache is initialised from the latent codes; the previous code passed an
        # undefined `encoder_outputs` name here and raised NameError.
        past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, latent_codes)

        # Decode everything but the last token, then the last token using the cache.
        outputs_cache = model.decode(
            decoder_input_ids[:, :-1],
            latent_codes,
            decoder_attention_mask=decoder_attention_mask_cache,
            past_key_values=past_key_values,
        )
        outputs_cache_next = model.decode(
            decoder_input_ids[:, -1:],
            latent_codes,
            past_key_values=outputs_cache.past_key_values,
            decoder_attention_mask=decoder_attention_mask_cache,
        )

        # Reference: decode the whole sequence in one call without a cache.
        outputs = model.decode(decoder_input_ids, latent_codes, decoder_attention_mask=decoder_attention_mask)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` with the inputs keyed by model argument name."""
        (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
        ) = self.prepare_config_and_inputs()

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
        return config, inputs_dict
|
|
|
|
@require_flax
class FlaxT5VaeModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
    """Unit tests for ``FlaxT5VaeForAutoencoding`` on a small randomly initialised config."""

    all_model_classes = (FlaxT5VaeForAutoencoding,) if is_flax_available() else ()
    is_encoder_decoder = True

    def setUp(self):
        self.model_tester = FlaxVaeModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5VaeConfig, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_v1_1(self):
        """Run the model check with untied embeddings and the gated-gelu feed-forward."""
        config_and_inputs = self.model_tester.prepare_config_and_inputs()

        config = config_and_inputs[0]
        config.tie_word_embeddings = False
        config.feed_forward_proj = "gated-gelu"
        self.model_tester.create_and_check_model(config, *config_and_inputs[1:])

    def test_use_cache_forward_with_attn_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs)

    def test_encode(self):
        """Check that ``encode`` gives latent codes of the same shape with and without JIT."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                # **kwargs absorbs the decoder inputs that `encode` does not take.
                @jax.jit
                def encode_jitted(input_ids, attention_mask=None, **kwargs):
                    return model.encode(input_ids=input_ids, attention_mask=attention_mask)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = encode_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = encode_jitted(**prepared_inputs_dict)

                expected_shape = (
                    inputs_dict["input_ids"].shape[0],
                    config.n_latent_tokens,
                    config.latent_token_size,
                )
                self.assertEqual(outputs.shape, expected_shape)
                self.assertEqual(jitted_outputs.shape, expected_shape)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_decode(self):
        """Check that ``decode`` gives outputs of the same shapes with and without JIT."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model = model_class(config)
                latent_codes = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])

                prepared_inputs_dict = {
                    "decoder_input_ids": inputs_dict["decoder_input_ids"],
                    "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
                    "latent_codes": latent_codes,
                }

                @jax.jit
                def decode_jitted(decoder_input_ids, decoder_attention_mask, latent_codes):
                    return model.decode(
                        decoder_input_ids=decoder_input_ids,
                        latent_codes=latent_codes,
                        decoder_attention_mask=decoder_attention_mask,
                    )

                with self.subTest("JIT Enabled"):
                    jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_save_and_load(self):
        """Check that parameters survive a ``save_pretrained`` / ``from_pretrained`` round trip."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        model = FlaxT5VaeForAutoencoding(config)
        model_params = flatten_dict(unfreeze(model.params))

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            head_model = FlaxT5VaeForAutoencoding.from_pretrained(tmpdirname)
            new_params = flatten_dict(unfreeze(head_model.params))

            for key, reloaded in new_params.items():
                # Use the largest absolute element-wise difference: a signed sum lets
                # negative or mutually cancelling differences slip under the threshold.
                max_diff = np.max(np.abs(np.asarray(model_params[key]) - np.asarray(reloaded))).item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
|
|
|
| |
def compute_kernel(x, y):
    """Return the pairwise kernel matrix between the rows of ``x`` and ``y``.

    ``x`` is reshaped to ``(n, 1, dim)`` and ``y`` to ``(1, m, dim)``, where
    ``dim = x.shape[1]``. Entry ``[i, j]`` of the result is
    ``exp(-mean((x[i] - y[j]) ** 2) / dim)``.
    """
    n_rows_x, dim = x.shape[0], x.shape[1]
    n_rows_y = y.shape[0]

    # Broadcasting (n, 1, dim) against (1, m, dim) yields every pairwise difference.
    x_col = jnp.reshape(x, (n_rows_x, 1, dim))
    y_row = jnp.reshape(y, (1, n_rows_y, dim))
    mean_sq_diff = jnp.mean(jnp.square(x_col - y_row), axis=2)

    return jnp.exp(-mean_sq_diff / dim * 1.0)
|
|
def compute_mmd(x, y):
    """Return ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))`` with ``k = compute_kernel``."""
    within_x = jnp.mean(compute_kernel(x, x))
    within_y = jnp.mean(compute_kernel(y, y))
    across = jnp.mean(compute_kernel(x, y))
    return within_x + within_y - 2 * across
|
|
def regulariser_loss(latent_codes, rng):
    """Return the MMD between ``latent_codes`` and standard-normal samples of the same shape.

    ``rng`` is the JAX PRNG key used to draw the normal samples.
    """
    prior_samples = jax.random.normal(rng, latent_codes.shape)
    mmd = compute_mmd(prior_samples, latent_codes)
    return mmd
|
|
def loss_fn(logits, labels, latent_codes, regulariser_rng):
    """Return the mean softmax cross-entropy plus the latent regulariser loss.

    The last position along the second-to-last axis of ``logits`` is dropped
    before comparison with the one-hot ``labels``. ``latent_codes`` is
    flattened to two dimensions (keeping its last axis) before being passed
    to ``regulariser_loss``.
    """
    num_classes = logits.shape[-1]
    trimmed_logits = logits[..., :-1, :]
    one_hot_labels = onehot(labels, num_classes)
    token_loss = optax.softmax_cross_entropy(trimmed_logits, one_hot_labels)

    flat_codes = latent_codes.reshape(-1, latent_codes.shape[-1])
    latent_penalty = regulariser_loss(flat_codes, regulariser_rng)

    return token_loss.mean() + latent_penalty.mean()
| |
|
|
|
|
@require_tokenizers
@require_flax
class FlaxT5VaeModelIntegrationTests(unittest.TestCase):
    """Slow integration test that runs one loss computation on a ``t5-small`` based model."""

    @slow
    def test_training_step(self):
        """
        For comparison run:
        >>> import t5  # pip install t5==0.7.1
        >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

        >>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
        >>> path_to_mtf_small_spm_model_path = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
        >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """
        config = T5VaeConfig(t5_model_name_or_path="t5-small", n_latent_tokens=2)
        tokenizer = AutoTokenizer.from_pretrained("t5-small")
        # Align both vocabulary sizes in the config with the tokenizer.
        vocab_size = len(tokenizer)
        config.t5.vocab_size = vocab_size
        config.vocab_size = vocab_size
        model = FlaxT5VaeForAutoencoding(config, seed=42, dtype=jnp.float32)

        input_ids = tokenizer("Hello there my name is fraser.", return_tensors="np").input_ids
        labels = input_ids.copy()

        # Append one pad token before shifting, so the decoder input is one token longer
        # than `labels`; `loss_fn` drops the final logit position to match.
        pad_input_ids = jnp.concatenate((input_ids, config.pad_token_id * jnp.ones((1, 1), dtype=jnp.int32)), axis=1)
        decoder_input_ids = shift_tokens_right(pad_input_ids, config.pad_token_id, config.decoder_start_token_id)

        outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
        logits, latent_codes = outputs[0], outputs[1]
        loss = loss_fn(logits, labels, latent_codes, jax.random.PRNGKey(42))

        # A leftover pdb breakpoint used to sit here, which hangs non-interactive runs;
        # assert on the result instead.
        self.assertTrue(bool(jnp.isfinite(loss).all()), msg=f"Loss is not finite: {loss}")
|
|