transformer-vae / tests /test_t5_vae.py
Fraser's picture
working saving & loading
7633929
import tempfile
import unittest
import random
import numpy as np
import optax
from flax.training.common_utils import onehot
from transformers import is_flax_available, AutoTokenizer
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from transformers.testing_utils import require_flax, require_tokenizers, slow
from tests.test_configuration_common import ConfigTester
from tests.test_generation_flax_utils import FlaxGenerationTesterMixin
from tests.test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import os
# The slow tests are often failing with OOM error on GPU
# This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
# but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_MAPPING
from model.t5_vae import FlaxT5VaeForAutoencoding, T5VaeConfig
class FlaxVaeModelTester:
    """Builds a tiny ``T5VaeConfig`` plus random inputs and runs model checks for the unit tests.

    ``parent`` is the ``unittest.TestCase`` whose ``assert*`` methods are used by the checks.
    """

    def __init__(
        self,
        parent,
        vocab_size=99,
        batch_size=13,
        seq_length=7,
        latent_token_size=10,
        n_latent_tokens=3,
        # For common tests
        is_training=True,
        use_attention_mask=True,
        use_labels=True,
        hidden_size=32,
        num_hidden_layers=5,
        num_attention_heads=4,
        d_ff=37,
        relative_attention_num_buckets=8,
        dropout_rate=0.1,
        initializer_factor=0.002,
        eos_token_id=1,
        pad_token_id=0,
        decoder_start_token_id=0,
        scope=None,
        decoder_layers=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.latent_token_size = latent_token_size
        self.n_latent_tokens = n_latent_tokens
        # For common tests
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_attention_mask = use_attention_mask
        self.use_labels = use_labels
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.d_ff = d_ff
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.initializer_factor = initializer_factor
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.decoder_start_token_id = decoder_start_token_id
        # Previously hard-coded to None, which silently discarded the ``scope`` argument.
        self.scope = scope
        self.decoder_layers = decoder_layers

    def prepare_config_and_inputs(self):
        """Build a config and random inputs.

        Returns:
            ``(config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask)``.
            ``decoder_input_ids`` is ``input_ids`` shifted right via ``shift_tokens_right``.
            The masks come from ``ids_tensor(..., vocab_size=2)`` when ``use_attention_mask``
            is set; otherwise they are ``None``.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
        # shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id). The third argument
        # previously reused pad_token_id; the values coincide with the defaults (both 0).
        decoder_input_ids = shift_tokens_right(input_ids, self.pad_token_id, self.decoder_start_token_id)

        attention_mask = None
        decoder_attention_mask = None
        if self.use_attention_mask:
            attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
            decoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        config = T5VaeConfig(
            latent_token_size=self.latent_token_size,
            n_latent_tokens=self.n_latent_tokens,
            vocab_size=self.vocab_size,
            d_model=self.hidden_size,
            block_size=self.seq_length,
            d_ff=self.d_ff,
            d_kv=self.hidden_size // self.num_attention_heads,
            num_layers=self.num_hidden_layers,
            num_decoder_layers=self.decoder_layers,
            num_heads=self.num_attention_heads,
            relative_attention_num_buckets=self.relative_attention_num_buckets,
            dropout_rate=self.dropout_rate,
            initializer_factor=self.initializer_factor,
            eos_token_id=self.eos_token_id,
            bos_token_id=self.pad_token_id,
            pad_token_id=self.pad_token_id,
            decoder_start_token_id=self.decoder_start_token_id,
        )

        return (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
        )

    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
    ):
        """Run the model with and without masks and check the encoder/decoder hidden-state shapes."""
        model = FlaxT5VaeForAutoencoding(config=config)
        # Smoke-test the masked call; only the unmasked call's outputs are inspected below.
        model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
        decoder_output = result.last_hidden_state
        encoder_output = result.encoder_last_hidden_state

        self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.seq_length, self.hidden_size))

    def check_use_cache_forward_with_attn_mask(
        self,
        model_class_name,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
    ):
        """Check that cached incremental decoding matches a full decode on the last position.

        Decodes all but the last token with a fresh cache, then the last token using the
        returned cache. Compares the first 5 features of the final position against an
        uncached decode of the whole sequence.
        """
        max_decoder_length = 20
        model = model_class_name(config)

        latent_codes = model.encode(input_ids)

        # prevent fully zero'd out attention mask
        decoder_attention_mask = jnp.ones_like(decoder_attention_mask)

        # Pad the mask out to the cache length; cache slots not yet written are masked off.
        decoder_attention_mask_cache = jnp.concatenate(
            [
                decoder_attention_mask,
                jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
            ],
            axis=-1,
        )

        # Previously referenced the undefined name ``encoder_outputs`` (NameError).
        # NOTE(review): assumes init_cache takes the latent codes in the slot where plain
        # FlaxT5 takes encoder outputs, mirroring decode() below — confirm against model.t5_vae.
        past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, latent_codes)

        outputs_cache = model.decode(
            decoder_input_ids[:, :-1],
            latent_codes,
            decoder_attention_mask=decoder_attention_mask_cache,
            past_key_values=past_key_values,
        )
        outputs_cache_next = model.decode(
            decoder_input_ids[:, -1:],
            latent_codes,
            past_key_values=outputs_cache.past_key_values,
            decoder_attention_mask=decoder_attention_mask_cache,
        )

        outputs = model.decode(decoder_input_ids, latent_codes, decoder_attention_mask=decoder_attention_mask)

        diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
        self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")

    def prepare_config_and_inputs_for_common(self):
        """Return ``(config, inputs_dict)`` in the format the common Flax test mixins expect."""
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
        ) = config_and_inputs

        inputs_dict = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
        }
        return config, inputs_dict
@require_flax
class FlaxT5VaeModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
    """Unit tests for ``FlaxT5VaeForAutoencoding`` using the tiny configs from ``FlaxVaeModelTester``."""

    all_model_classes = (FlaxT5VaeForAutoencoding,) if is_flax_available() else ()
    is_encoder_decoder = True

    def setUp(self):
        self.model_tester = FlaxVaeModelTester(self)
        self.config_tester = ConfigTester(self, config_class=T5VaeConfig, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_v1_1(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # check that gated gelu feed forward and different word embeddings work
        config = config_and_inputs[0]
        config.tie_word_embeddings = False
        config.feed_forward_proj = "gated-gelu"
        self.model_tester.create_and_check_model(config, *config_and_inputs[1:])

    def test_use_cache_forward_with_attn_mask(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs)

    def test_encode(self):
        """``encode`` returns ``(batch, n_latent_tokens, latent_token_size)`` with and without JIT."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        expected_shape = (inputs_dict["input_ids"].shape[0], config.n_latent_tokens, config.latent_token_size)

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def encode_jitted(input_ids, attention_mask=None, **kwargs):
                    # The decoder_* entries of the inputs dict are accepted here and ignored.
                    return model.encode(input_ids=input_ids, attention_mask=attention_mask)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = encode_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = encode_jitted(**prepared_inputs_dict)

                self.assertEqual(outputs.shape, expected_shape)
                self.assertEqual(jitted_outputs.shape, expected_shape)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_decode(self):
        """``decode`` from encoded latent codes gives same-shaped outputs with and without JIT."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model = model_class(config)
                latent_codes = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])

                prepared_inputs_dict = {
                    "decoder_input_ids": inputs_dict["decoder_input_ids"],
                    "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
                    "latent_codes": latent_codes,
                }

                @jax.jit
                def decode_jitted(decoder_input_ids, decoder_attention_mask, latent_codes):
                    return model.decode(
                        decoder_input_ids=decoder_input_ids,
                        latent_codes=latent_codes,
                        decoder_attention_mask=decoder_attention_mask,
                    )

                with self.subTest("JIT Enabled"):
                    jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

    def test_save_and_load(self):
        """A ``save_pretrained`` / ``from_pretrained`` round trip reproduces every parameter."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        model = FlaxT5VaeForAutoencoding(config)
        model_params = flatten_dict(unfreeze(model.params))

        # check that all base model weights are loaded correctly
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
            head_model = FlaxT5VaeForAutoencoding.from_pretrained(tmpdirname)

            new_params = flatten_dict(unfreeze(head_model.params))
            # No parameter may be dropped or invented by the round trip.
            self.assertEqual(set(model_params.keys()), set(new_params.keys()))
            for key in new_params.keys():
                # Use the max *absolute* difference: a signed sum can cancel out or go
                # negative and let mismatched weights pass the check.
                max_diff = jnp.abs(model_params[key] - new_params[key]).max().item()
                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
## Copied training methods
def compute_kernel(x, y):
    """Compute the Gaussian (RBF) kernel matrix between the rows of ``x`` and ``y``.

    Args:
        x: 2-D array of shape ``(n, dim)``.
        y: 2-D array of shape ``(m, dim)`` (same ``dim`` as ``x``).

    Returns:
        ``(n, m)`` array whose ``[i, j]`` entry is
        ``exp(-mean((x[i] - y[j]) ** 2) / dim)``.
    """
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    dim = x.shape[1]
    # Broadcasting (n, 1, dim) against (1, m, dim) gives the same pairwise differences
    # as tiling both inputs with jnp.repeat, without materialising the two tiled copies.
    sq_diff = (x[:, None, :] - y[None, :, :]) ** 2
    return jnp.exp(-jnp.mean(sq_diff, axis=2) / dim)
def compute_mmd(x, y):
    """Estimate the maximum mean discrepancy between sample sets ``x`` and ``y``.

    Built from ``compute_kernel``: ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))``.
    """
    within_x = jnp.mean(compute_kernel(x, x))
    within_y = jnp.mean(compute_kernel(y, y))
    cross = jnp.mean(compute_kernel(x, y))
    return within_x + within_y - 2 * cross
def regulariser_loss(latent_codes, rng):
    """Return the MMD between ``latent_codes`` and same-shaped standard-normal samples.

    Args:
        latent_codes: array of latent vectors, passed to ``compute_mmd`` as given.
        rng: JAX PRNG key used to draw the reference samples.
    """
    prior_draws = jax.random.normal(rng, latent_codes.shape)
    return compute_mmd(prior_draws, latent_codes)
def loss_fn(logits, labels, latent_codes, regulariser_rng):
    """Return the mean reconstruction cross-entropy plus the MMD latent regulariser.

    The last position of ``logits`` is dropped before it is compared against one-hot
    ``labels``. ``latent_codes`` is flattened to ``(-1, latent_codes.shape[-1])``
    before the regulariser is computed.
    """
    vocab_size = logits.shape[-1]
    targets = onehot(labels, vocab_size)
    token_losses = optax.softmax_cross_entropy(logits[..., :-1, :], targets)
    flat_latents = latent_codes.reshape(-1, latent_codes.shape[-1])
    mmd = regulariser_loss(flat_latents, regulariser_rng)
    return token_losses.mean() + mmd.mean()
##
@require_tokenizers
@require_flax
class FlaxT5VaeModelIntegrationTests(unittest.TestCase):
    """Slow integration tests that build a T5-VAE from the ``t5-small`` checkpoint."""

    @slow
    def test_training_step(self):
        """
        Run one forward pass on a short sentence and check the training loss is finite.

        For comparison run:
        >>> import t5  # pip install t5==0.7.1
        >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

        >>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
        >>> path_to_mtf_small_spm_model_path = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
        >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """
        config = T5VaeConfig(t5_model_name_or_path="t5-small", n_latent_tokens=2)
        tokenizer = AutoTokenizer.from_pretrained("t5-small")
        vocab_size = len(tokenizer)
        config.t5.vocab_size = vocab_size
        config.vocab_size = vocab_size
        model = FlaxT5VaeForAutoencoding(config, seed=42, dtype=jnp.float32)

        input_ids = tokenizer("Hello there my name is fraser.", return_tensors="np").input_ids
        labels = input_ids.copy()

        # pad right so not losing a token on shift
        pad_input_ids = jnp.concatenate((input_ids, config.pad_token_id * jnp.ones((1, 1), dtype=jnp.int32)), axis=1)
        decoder_input_ids = shift_tokens_right(pad_input_ids, config.pad_token_id, config.decoder_start_token_id)

        outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
        logits, latent_codes = outputs[0], outputs[1]

        loss = loss_fn(logits, labels, latent_codes, jax.random.PRNGKey(42))
        # A leftover ``pdb.set_trace()`` used to sit here, blocking any test run waiting for
        # interactive input and asserting nothing. Check the loss is usable instead.
        self.assertTrue(bool(jnp.isfinite(loss)), msg=f"Non-finite loss: {loss}")