import tempfile
import unittest

import numpy as np
import optax
from flax.training.common_utils import onehot

from transformers import AutoTokenizer, is_flax_available
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from transformers.testing_utils import require_flax, require_tokenizers, slow

from tests.test_configuration_common import ConfigTester
from tests.test_generation_flax_utils import FlaxGenerationTesterMixin
from tests.test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor


if is_flax_available():
    import os

    # The slow tests often fail with an OOM error on GPU. This makes JAX
    # allocate exactly what is needed on demand and deallocate memory that is
    # no longer needed, at the cost of speed, as described at
    # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
    os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

    import jax
    import jax.numpy as jnp
    from flax.core.frozen_dict import unfreeze
    from flax.traverse_util import flatten_dict

    from model.t5_vae import FlaxT5VaeForAutoencoding, T5VaeConfig


class FlaxVaeModelTester:
    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        seq_length=7,
        latent_token_size=10,
        n_latent_tokens=3,
        # For common tests
        is_training=True,
        use_attention_mask=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        d_ff=37,
        relative_attention_num_buckets=8,
        dropout_rate=0.1,
        initializer_factor=0.002,
        eos_token_id=1,
        pad_token_id=0,
        decoder_start_token_id=0,
        scope=None,
        decoder_layers=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.latent_token_size = latent_token_size
        self.n_latent_tokens = n_latent_tokens
        # For common tests
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.d_ff = d_ff
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.initializer_factor = initializer_factor
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        self.scope = None
        self.decoder_layers = decoder_layers

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        decoder_input_ids = shift_tokens_right(input_ids, self.pad_token_id, self.decoder_start_token_id)

        attention_mask = None
        decoder_attention_mask = None
        if self.use_attention_mask:
            attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
            decoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        config = T5VaeConfig(
            latent_token_size=self.latent_token_size,
            n_latent_tokens=self.n_latent_tokens,
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            block_size=self.seq_length,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_decoder_layers=self.decoder_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
        )

        return (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
        )
    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
    ):
        model = FlaxT5VaeForAutoencoding(config=config)
        result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
        decoder_output = result.last_hidden_state
        encoder_output = result.encoder_last_hidden_state

        self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def check_use_cache_forward_with_attn_mask(
        self,
        model_class_name,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
    ):
        max_decoder_length = 20
        model = model_class_name(config)
        latent_codes = model.encode(input_ids)

        # prevent fully zero'd out attention mask
        decoder_attention_mask = jnp.ones_like(decoder_attention_mask)
        decoder_attention_mask_cache = jnp.concatenate(
            [
                decoder_attention_mask,
                jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
            ],
            axis=-1,
        )

        # the latent codes play the role of encoder outputs when initialising the cache
        past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, latent_codes)
        outputs_cache = model.decode(
            decoder_input_ids[:, :-1],
            latent_codes,
            decoder_attention_mask=decoder_attention_mask_cache,
            past_key_values=past_key_values,
        )
        outputs_cache_next = model.decode(
            decoder_input_ids[:, -1:],
            latent_codes,
            past_key_values=outputs_cache.past_key_values,
            decoder_attention_mask=decoder_attention_mask_cache,
        )

        outputs = model.decode(decoder_input_ids, latent_codes, decoder_attention_mask=decoder_attention_mask)

        diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
        ) = config_and_inputs
        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
        return config, inputs_dict

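# Added illustration (not part of the original suite): a minimal sanity check of
# the `shift_tokens_right` contract that `prepare_config_and_inputs` above
# relies on. The decoder input is the target sequence shifted one position to
# the right, with `decoder_start_token_id` prepended and the final token dropped.
@require_flax
class ShiftTokensRightSanityTest(unittest.TestCase):
    def test_shift_prepends_start_token(self):
        input_ids = np.array([[5, 6, 7, 1]], dtype=np.int32)
        shifted = np.asarray(shift_tokens_right(input_ids, pad_token_id=0, decoder_start_token_id=0))
        self.assertEqual(shifted[0, 0], 0)  # decoder start token comes first
        self.assertTrue((shifted[0, 1:] == input_ids[0, :-1]).all())
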
@require_flax
class FlaxT5VaeModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
    all_model_classes = (FlaxT5VaeForAutoencoding,) if is_flax_available() else ()
    is_encoder_decoder = True

    def setUp(self):
        self.model_tester = FlaxVaeModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5VaeConfig, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_v1_1(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # check that gated gelu feed forward and different word embeddings work
        config = config_and_inputs[0]
        config.tie_word_embeddings = False
        config.feed_forward_proj = "gated-gelu"
        self.model_tester.create_and_check_model(config, *config_and_inputs[1:])

    def test_use_cache_forward_with_attn_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs)

    def test_encode(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def encode_jitted(input_ids, attention_mask=None, **kwargs):
                    return model.encode(input_ids=input_ids, attention_mask=attention_mask)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = encode_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = encode_jitted(**prepared_inputs_dict)

                expected_shape = (inputs_dict["input_ids"].shape[0], config.n_latent_tokens, config.latent_token_size)
                self.assertEqual(outputs.shape, expected_shape)
                self.assertEqual(jitted_outputs.shape, expected_shape)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_decode(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model = model_class(config)
                latent_codes = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])

                prepared_inputs_dict = {
                    "decoder_input_ids": inputs_dict["decoder_input_ids"],
                    "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
                    "latent_codes": latent_codes,
                }

                @jax.jit
                def decode_jitted(decoder_input_ids, decoder_attention_mask, latent_codes):
                    return model.decode(
                        decoder_input_ids=decoder_input_ids,
                        latent_codes=latent_codes,
                        decoder_attention_mask=decoder_attention_mask,
                    )

                with self.subTest("JIT Enabled"):
                    jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_save_and_load(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = FlaxT5VaeForAutoencoding(config)
        model_params = flatten_dict(unfreeze(model.params))

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            head_model = FlaxT5VaeForAutoencoding.from_pretrained(tmpdirname)

            new_params = flatten_dict(unfreeze(head_model.params))

            for key in new_params.keys():
                max_diff = (model_params[key] - new_params[key]).sum().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")


## Copied training methods
def compute_kernel(x, y):
    x_size = x.shape[0]
    y_size = y.shape[0]
    dim = x.shape[1]
    tiled_x = jnp.repeat(jnp.reshape(x, (x_size, 1, dim)), y_size, axis=1)
    tiled_y = jnp.repeat(jnp.reshape(y, (1, y_size, dim)), x_size, axis=0)
    return jnp.exp(-jnp.mean((tiled_x - tiled_y) ** 2, axis=2) / dim * 1.0)


def compute_mmd(x, y):
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return jnp.mean(x_kernel) + jnp.mean(y_kernel) - 2 * jnp.mean(xy_kernel)


def regulariser_loss(latent_codes, rng):
    true_samples = jax.random.normal(rng, latent_codes.shape)
    return compute_mmd(true_samples, latent_codes)


def loss_fn(logits, labels, latent_codes, regulariser_rng):
    # drop the logit at the final (padding) position so logits align with labels
    shift_logits = logits[..., :-1, :]
    loss = optax.softmax_cross_entropy(shift_logits, onehot(labels, logits.shape[-1]))
    reg_loss = regulariser_loss(latent_codes.reshape(-1, latent_codes.shape[-1]), regulariser_rng)
    return loss.mean() + reg_loss.mean()
##

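# Added illustration (not part of the original suite): a sanity check of the
# MMD helpers above. With both sample sets drawn from the same standard normal
# the MMD should be close to zero, and it should grow once one set is shifted;
# the shift of 3.0 and the sample sizes here are arbitrary assumptions.
@require_flax
class MMDHelperSanityTest(unittest.TestCase):
    def test_mmd_separates_shifted_distributions(self):
        rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
        x = jax.random.normal(rng_x, (128, 16))
        y = jax.random.normal(rng_y, (128, 16))
        mmd_same = compute_mmd(x, y)
        mmd_shifted = compute_mmd(x, y + 3.0)
        self.assertLess(float(mmd_same), float(mmd_shifted))
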
@require_tokenizers
@require_flax
class FlaxT5VaeModelIntegrationTests(unittest.TestCase):
    @slow
    def test_training_step(self):
        """
        For comparison run:
        >>> import t5  # pip install t5==0.7.1
        >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

        >>> path_to_mtf_small_t5_checkpoint = ''
        >>> path_to_mtf_small_spm_model_path = ''
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
        >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)

        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """
        config = T5VaeConfig(t5_model_name_or_path="t5-small", n_latent_tokens=2)
        tokenizer = AutoTokenizer.from_pretrained("t5-small")
        vocab_size = len(tokenizer)
        config.t5.vocab_size = vocab_size
        config.vocab_size = vocab_size
        model = FlaxT5VaeForAutoencoding(config, seed=42, dtype=jnp.float32)

        input_ids = tokenizer("Hello there my name is fraser.", return_tensors="np").input_ids
        labels = input_ids.copy()
        # pad on the right so the shift does not lose a token
        pad_input_ids = jnp.concatenate((input_ids, config.pad_token_id * jnp.ones((1, 1), dtype=jnp.int32)), axis=1)
        decoder_input_ids = shift_tokens_right(pad_input_ids, config.pad_token_id, config.decoder_start_token_id)

        outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
        logits, latent_codes = outputs[0], outputs[1]

        loss = loss_fn(logits, labels, latent_codes, jax.random.PRNGKey(42))
        self.assertTrue(jnp.isfinite(loss).item(), msg=f"Loss is not finite: {loss}")
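
# Added illustration (not part of the original suite): a minimal sketch of how
# `loss_fn` above could drive a full optax update. The `params` keyword and the
# (logits, latent_codes) output order are assumed from `test_training_step`;
# treat this as a sketch under those assumptions, not a verified training loop.
def train_step_sketch(model, params, opt_state, optimizer, batch, rng):
    def objective(params):
        outputs = model(
            batch["input_ids"],
            decoder_input_ids=batch["decoder_input_ids"],
            params=params,  # assumption: the model accepts explicit params like HF Flax models
        )
        logits, latent_codes = outputs[0], outputs[1]
        return loss_fn(logits, batch["labels"], latent_codes, rng)

    # differentiate the combined reconstruction + MMD objective and apply the update
    loss, grads = jax.value_and_grad(objective)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss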