|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for ops_text.""" |
|
|
|
|
|
import copy |
|
|
|
|
|
from absl.testing import parameterized |
|
|
import big_vision.pp.ops_text as pp |
|
|
from big_vision.pp.registry import Registry |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
|
|
|
|
|
|
class PyToTfWrapper:
  """Adapter that exposes `to_{int,str}_tf_op()` under the `to_{int,str}()` API.

  This lets a single test body exercise both the plain-Python and the TF-op
  code paths of a tokenizer: TF outputs are converted back to Python lists /
  strings so the assertions can stay identical.
  """

  def __init__(self, tok):
    # Mirror the attributes tests read directly off the tokenizer.
    self.tok = tok
    self.bos_token = tok.bos_token
    self.eos_token = tok.eos_token
    self.vocab_size = tok.vocab_size

  def to_int(self, text, *, bos=False, eos=False):
    """Tokenizes via the TF op and returns plain Python int lists."""
    encoded = self.tok.to_int_tf_op(text, bos=bos, eos=eos)
    if not isinstance(encoded, tf.RaggedTensor):
      return encoded.numpy().tolist()
    # Ragged (batched) output: one list of ids per input string.
    return [row.numpy().tolist() for row in encoded]

  def to_str(self, tokens, stop_at_eos=True):
    """Detokenizes via the TF op and returns Python str(s)."""
    decoded = self.tok.to_str_tf_op(
        tf.ragged.constant(tokens),
        stop_at_eos=stop_at_eos,
    )
    if decoded.ndim:
      return [s.numpy().decode() for s in decoded]
    # Scalar output for a single (unbatched) token sequence.
    return decoded.numpy().decode()
|
|
|
|
|
|
|
|
class PpOpsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the text preprocessing ops in `big_vision.pp.ops_text`."""

  def tfrun(self, ppfn, data):
    """Applies `ppfn` both eagerly and via tf.data, yielding numpy dicts."""
    # First: direct (eager) application on a defensive copy of the input.
    eager_out = ppfn(copy.deepcopy(data))
    yield {key: np.array(val) for key, val in eager_out.items()}

    # Second: the same function traced inside a tf.data pipeline.
    ds = tf.data.Dataset.from_tensors(copy.deepcopy(data))
    yield from ds.map(ppfn).as_numpy_iterator()

  def testtok(self):
    # Path of the tiny sentencepiece test model used by the tokenizer tests.
    return "test_model.model"

  def test_get_pp_clip_i1k_label_names(self):
    names_op = pp.get_pp_clip_i1k_label_names()
    out = names_op({"label": tf.constant([0, 1])})
    self.assertAllEqual(out["labels"].numpy().tolist(), ["tench", "goldfish"])

  def test_get_pp_i21k_label_names(self):
    names_op = pp.get_pp_i21k_label_names()
    out = names_op({"label": tf.constant([0, 1])})
    self.assertAllEqual(out["labels"].numpy().tolist(), ["organism", "benthos"])

  @parameterized.parameters((b"Hello world ScAlAr!", b"hello world scalar!"),
                            (["Decoded Array!"], ["decoded array!"]),
                            ([b"aA", "bB"], [b"aa", "bb"]))
  def test_get_lower(self, inputs, expected_output):
    # Lowercasing must work for scalars, str arrays and bytes arrays alike.
    lowered = pp.get_lower()({"text": tf.constant(inputs)})
    self.assertAllEqual(lowered["text"].numpy(), np.array(expected_output))

  @parameterized.named_parameters(
      ("py", False),
      ("tf", True),
  )
  def test_sentencepiece_tokenizer(self, wrap_tok):
    tok = pp.SentencepieceTokenizer(self.testtok())
    if wrap_tok:
      # Route all calls through the TF-op code path.
      tok = PyToTfWrapper(tok)
    self.assertEqual(tok.vocab_size, 1000)
    bos, eos = tok.bos_token, tok.eos_token
    self.assertEqual(bos, 1)
    self.assertEqual(eos, 2)

    # Encoding: optional bos/eos markers, scalar and batched inputs.
    self.assertEqual(tok.to_int("blah"), [80, 180, 60])
    self.assertEqual(tok.to_int("blah", bos=True), [bos, 80, 180, 60])
    self.assertEqual(tok.to_int("blah", eos=True), [80, 180, 60, eos])
    self.assertEqual(
        tok.to_int("blah", bos=True, eos=True), [bos, 80, 180, 60, eos]
    )
    self.assertEqual(
        tok.to_int(["blah", "blah blah"]),
        [[80, 180, 60], [80, 180, 60, 80, 180, 60]],
    )

    # Decoding: bos/eos ids are dropped, batched input yields a list.
    self.assertEqual(tok.to_str([80, 180, 60]), "blah")
    self.assertEqual(tok.to_str([1, 80, 180, 60]), "blah")
    self.assertEqual(tok.to_str([80, 180, 60, 2]), "blah")
    self.assertEqual(
        tok.to_str([[80, 180, 60], [80, 180, 60, 80, 180, 60]]),
        ["blah", "blah blah"],
    )

  def test_sentencepiece_tokenizer_tf_op_ndarray_input(self):
    tok = pp.SentencepieceTokenizer(self.testtok())
    bos, eos = tok.bos_token, tok.eos_token
    # The TF op must accept a plain numpy batch, not only tf tensors.
    batch = np.array([[bos, 80, 180, 60, eos]] * 2, dtype=np.int32)
    self.assertEqual(tok.to_str_tf_op(batch).numpy().tolist(), [b"blah"] * 2)

  def test_sentencepiece_tokenizer_tokensets(self):
    # The "loc" tokenset appends 1024 location tokens after the base vocab.
    tok = pp.SentencepieceTokenizer(self.testtok(), tokensets=["loc"])
    self.assertEqual(tok.vocab_size, 2024)
    self.assertEqual(
        tok.to_int("blah<loc0000><loc1023>"), [80, 180, 60, 1000, 2023]
    )

  def test_sentencepiece_stop_at_eos(self):
    tok = pp.SentencepieceTokenizer(self.testtok())
    eos = tok.eos_token
    # Without stop_at_eos the eos id is simply skipped during decoding.
    self.assertEqual(tok.to_str([80, 180, 60], stop_at_eos=False), "blah")
    self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=False), "blah")
    # With stop_at_eos decoding truncates at the first eos, per sequence.
    self.assertEqual(tok.to_str([80, eos, 180, 60], stop_at_eos=True), "b")
    self.assertEqual(
        tok.to_str([[80, eos, 180, 60], [80, 180, eos, 60]], stop_at_eos=True),
        ["b", "bla"]
    )

  def test_sentencepiece_extra_tokens(self):
    # By default bos/eos are control tokens and are not rendered.
    tok = pp.SentencepieceTokenizer(self.testtok())
    self.assertEqual(tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "blah")
    # Re-registering them via a tokenset makes them visible in the output.
    tok = pp.SentencepieceTokenizer(
        self.testtok(), tokensets=["sp_extra_tokens"]
    )
    self.assertEqual(tok.vocab_size, 1001)
    self.assertEqual(
        tok.to_str([1, 80, 180, 60, 2], stop_at_eos=False), "<s> blah</s>"
    )

  def test_strfmt(self):
    data = {
        "int": tf.constant(42, tf.uint8),
        "float": tf.constant(3.14, tf.float32),
        "vec": tf.range(3),
        "empty_str": tf.constant(""),
        "regex_problem1": tf.constant(r"no \replace pattern"),
        "regex_problem2": tf.constant(r"yes \1 pattern"),
    }
    # (template, expected bytes) pairs; the regex cases guard against
    # backslash sequences being interpreted as substitution patterns.
    cases = [
        ("Nothing", b"Nothing"),
        ("{int}", b"42"),
        ("A{int}", b"A42"),
        ("{int}A", b"42A"),
        ("{int}{int}", b"4242"),
        ("A{int}A{int}A", b"A42A42A"),
        ("A{float}A", b"A3.14A"),
        ("A{float}A{int}", b"A3.14A42"),
        ("A{vec}A", b"A[0 1 2]A"),
        ("A{empty_str}A", b"AA"),
        ("{empty_str}", b""),
        ("A{regex_problem1}A", br"Ano \replace patternA"),
        ("A{regex_problem2}A", br"Ayes \1 patternA"),
    ]
    for template, expected in cases:
      for out in self.tfrun(pp.get_strfmt(template), data):
        self.assertEqual(out["text"], expected)
|
|
|
|
|
|
|
|
@Registry.register("tokensets.sp_extra_tokens")
def _get_sp_extra_tokens():
  """Tokenset exposing sentencepiece's special control tokens.

  Used by `test_sentencepiece_extra_tokens` to make `<s>`/`</s>` visible in
  decoded output instead of being treated as invisible control ids.
  """
  return ["<s>", "</s>", "<pad>"]
|
|
|
|
|
|
|
|
# Run all tests in this module when executed as a script.
if __name__ == "__main__":
  tf.test.main()
|
|
|