| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Tests for vit vqvae model.""" |
| from absl.testing import absltest |
|
|
| from big_vision.models.proj.uvim import vit |
| import jax |
| import jax.numpy as jnp |
| import ml_collections |
|
|
|
|
| class ViTVQVAEModelTest(absltest.TestCase): |
|
|
| def test_model(self): |
| model_config = ml_collections.ConfigDict({ |
| "input_size": (32, 32), |
| "code_len": 4, |
| "width": 16, |
| "mlp_dim": 64, |
| "num_heads": 4, |
| "enc_depth": 1, |
| "dec_depth": 1, |
| "with_encoder_ctx": True, |
| "with_decoder_ctx": True, |
| "statistics_axis_name": None, |
| "inputs": { |
| "in1": (10, 3), |
| "in2": (25,), |
| }, |
| "outputs": { |
| "out1": (5,), |
| "out2": (20,), |
| }, |
| }) |
|
|
| model = vit.Model(**model_config) |
| batch_size = 4 |
| seq_len = (32 // 8) ** 2 |
| x = { |
| "in1": jnp.zeros((batch_size, seq_len, 10, 3)), |
| "in2": jnp.zeros((batch_size, seq_len, 25)), |
| } |
| ctx_image = jnp.zeros((batch_size,) + model_config.input_size + (3,)) |
| init_rngs = { |
| "params": jax.random.PRNGKey(0), |
| "state": jax.random.PRNGKey(1), |
| } |
| params = model.init(init_rngs, x, ctx=ctx_image) |
| self.assertEqual(params.keys(), set(["params", "state"])) |
|
|
| apply_rngs = { |
| "dropout": jax.random.PRNGKey(0), |
| "vqvae": jax.random.PRNGKey(0), |
| } |
| (logits, _), params = model.apply( |
| params, x, ctx=ctx_image, train=True, update_dict=True, |
| rngs=apply_rngs, mutable=["state"]) |
| self.assertEqual(logits.keys(), set(["out1", "out2"])) |
| self.assertEqual(logits["out1"].shape, (batch_size, seq_len, 5)) |
| self.assertEqual(logits["out2"].shape, (batch_size, seq_len, 20)) |
|
|
|
|
| if __name__ == "__main__": |
| absltest.main() |
|
|