| import haiku | |
| import haiku as hk | |
| import jax.numpy as jnp | |
| import jax.random | |
| from vietTTS.nat.config import FLAGS | |
| from vietTTS.nat.model import DurationModel | |
def test_duration():
    """Smoke-test that DurationModel runs end to end and yields the expected shape.

    Haiku modules can only be constructed and called inside a transformed
    function, so the model is built inside ``forward``. The function is then
    wrapped with ``hk.transform_with_state``, initialised, and applied. The
    ``_with_state`` variant and the explicit PRNG key are used so the test
    still works if the model keeps Haiku state or uses randomness (e.g.
    dropout).
    """

    def forward(phonemes, lengths):
        # Constructing the module here, not at test scope, is required by Haiku.
        return DurationModel()(phonemes, lengths)

    net = hk.transform_with_state(forward)
    rng = jax.random.PRNGKey(42)

    # Batch of 2 sequences, 10 integer token ids each (all zeros), plus one
    # length per sequence (also zeros, as in the original test).
    phonemes = jnp.zeros((2, 10), dtype=jnp.int32)
    lengths = jnp.zeros((2,), dtype=jnp.int32)

    params, state = net.init(rng, phonemes, lengths)
    out, _ = net.apply(params, state, rng, phonemes, lengths)

    # NOTE(review): expected shape kept from the original test; confirm it
    # matches DurationModel's current output (it may squeeze the last axis).
    assert out.shape == (2, 10, 1)