| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Unit tests for few-shot utils.""" |
|
|
| from absl.testing import absltest |
| from big_vision.evaluators import fewshot as bv_fewshot |
| import jax |
| from jax import random |
| from scenic.train_lib.transfer import fewshot_utils |
|
|
# Turn off the partitionable threefry PRNG mode before any key is created in
# this module.
# NOTE(review): presumably this pins the random stream the test data is drawn
# from so it does not depend on the installed JAX default -- confirm.
jax.config.update('jax_threefry_partitionable', False)
|
|
|
|
def big_vision_linear_regression(x, y, x_test, y_test, l2_reg, num_classes):
  """Returns the few-shot accuracy from big_vision's eigenvalue-based solver.

  Serves as the reference implementation that the scenic few-shot accuracy is
  compared against in this file.

  Args:
    x: Training inputs, forwarded to `_precompute_cache`.
    y: Training labels, forwarded to `_precompute_cache`.
    x_test: Evaluation inputs, forwarded to `_eig_fewshot_acc_fn`.
    y_test: Evaluation labels, forwarded to `_eig_fewshot_acc_fn`.
    l2_reg: L2 regularization value, forwarded to `_eig_fewshot_acc_fn`.
    num_classes: Number of classes, forwarded to `_precompute_cache`.

  Returns:
    Whatever `_eig_fewshot_acc_fn` returns for the given evaluation data; the
    caller in this file treats it as an accuracy.
  """
  # The cache is built once from the training data and then handed to the
  # accuracy function together with the evaluation data and the regularizer.
  precomputed = bv_fewshot._precompute_cache(x, y, num_classes)
  return bv_fewshot._eig_fewshot_acc_fn(precomputed, x_test, y_test, l2_reg)
|
|
|
|
class LinearRegressionTest(absltest.TestCase):
  """Tests linear regression used in few-shot evaluation."""

  def test_linear_regression(self):
    """Checks `_fewshot_acc_fn` against big_vision and across label formats.

    For every L2 regularizer, the accuracy computed from integer labels must
    lie strictly between 0 and 1, agree with the big_vision reference to
    within 1e-6, and equal the accuracy computed from one-hot labels.
    """
    num_points = 512
    dim = 16
    num_classes = 5
    l2_regs = [1.0, 2.0, 8.0, 0.0]

    # Every array gets its own key. Drawing same-shaped arrays from a single
    # key yields identical values, which would make the evaluation split an
    # exact copy of the training split.
    x_key, x_test_key, y_key, y_test_key = random.split(random.PRNGKey(0), 4)

    # Random inputs and integer class labels for both splits.
    x = random.normal(x_key, shape=(num_points, dim))
    x_test = random.normal(x_test_key, shape=(num_points, dim))
    y = random.randint(
        y_key, shape=(num_points,), minval=0, maxval=num_classes)
    y_test = random.randint(
        y_test_key, shape=(num_points,), minval=0, maxval=num_classes)

    # The one-hot labels do not depend on the regularizer, so build them once.
    y_one_hot = jax.nn.one_hot(y, num_classes)
    y_test_one_hot = jax.nn.one_hot(y_test, num_classes)

    for l2_reg in l2_regs:
      # subTest reports each regularizer separately instead of stopping at the
      # first failing value.
      with self.subTest(l2_reg=l2_reg):
        # Accuracy from integer labels.
        accuracy = fewshot_utils._fewshot_acc_fn(
            x,
            y,
            x_test,
            y_test,
            l2_reg,
            num_classes,
            target_is_one_hot=False)

        # Reference value from the big_vision implementation.
        expected_accuracy = big_vision_linear_regression(
            x, y, x_test, y_test, l2_reg, num_classes)
        self.assertGreater(accuracy, 0)
        self.assertLess(accuracy, 1)
        self.assertAlmostEqual(accuracy, expected_accuracy, delta=1e-6)

        # The one-hot label path must give exactly the same accuracy.
        accuracy_one_hot = fewshot_utils._fewshot_acc_fn(
            x,
            y_one_hot,
            x_test,
            y_test_one_hot,
            l2_reg,
            num_classes,
            target_is_one_hot=True)
        self.assertEqual(accuracy, accuracy_one_hot)
|
|
|
|
if __name__ == '__main__':
  # Hand control to the absltest runner when this file is executed directly.
  absltest.main()
|
|