| """ |
| Comprehensive tests for domainTokenizer core library. |
| 72 tests covering: schemas, field tokenizers, predefined schemas, |
| DomainTokenizerBuilder pipeline, and end-to-end HF encoding. |
| |
| Run: pytest tests/test_tokenizer.py -v |
| """ |
|
|
| import json |
| import math |
| import sys |
| from datetime import datetime |
|
|
| import numpy as np |
| import pytest |
|
|
| from domain_tokenizer.schema import DomainSchema, FieldSpec, FieldType, CALENDAR_FIELD_SIZES |
| from domain_tokenizer.tokenizers.field_tokenizers import ( |
| SignTokenizer, MagnitudeBucketTokenizer, DiscreteNumericalTokenizer, |
| CalendarTokenizer, CategoricalTokenizer, create_field_tokenizer, |
| ) |
| from domain_tokenizer.tokenizers.domain_tokenizer import DomainTokenizerBuilder |
| from domain_tokenizer.schemas.predefined import FINANCE_SCHEMA, ECOMMERCE_SCHEMA, HEALTHCARE_SCHEMA |
|
|
|
|
class TestFieldSpec:
    """FieldSpec: per-field vocabulary size and constructor validation."""

    def test_sign_field(self):
        """A SIGN field owns two tokens and contributes one token per event."""
        sign_spec = FieldSpec("amount_sign", FieldType.SIGN)
        assert sign_spec.token_count == 2
        assert sign_spec.tokens_per_event == 1

    def test_numerical_continuous_field(self):
        """A continuous numeric field owns exactly ``n_bins`` tokens."""
        amount_spec = FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, n_bins=21)
        assert amount_spec.token_count == 21

    def test_numerical_discrete_field(self):
        """Discrete field: one token per value 0..max_value plus one overflow token."""
        qty_spec = FieldSpec("quantity", FieldType.NUMERICAL_DISCRETE, max_value=10)
        assert qty_spec.token_count == (10 + 1) + 1

    def test_categorical_field(self):
        """Categorical field: one token per category plus an UNK token."""
        labels = ["a", "b", "c"]
        cat_spec = FieldSpec("event_type", FieldType.CATEGORICAL_FIXED, categories=labels)
        assert cat_spec.token_count == len(labels) + 1

    def test_temporal_field(self):
        """Temporal field: month (12) + day-of-week (7) + day-of-month (31) + hour (24)."""
        ts_spec = FieldSpec(
            "ts", FieldType.TEMPORAL, calendar_fields=["month", "dow", "dom", "hour"]
        )
        assert ts_spec.token_count == 12 + 7 + 31 + 24

    def test_text_field(self):
        """TEXT fields reserve no special tokens (presumably handled by the BPE part)."""
        assert FieldSpec("desc", FieldType.TEXT).token_count == 0

    def test_custom_prefix(self):
        """An explicit ``prefix`` is stored verbatim."""
        price_spec = FieldSpec("amount", FieldType.NUMERICAL_CONTINUOUS, prefix="PRICE")
        assert price_spec.prefix == "PRICE"

    def test_categorical_requires_categories(self):
        """CATEGORICAL_FIXED without ``categories`` is rejected."""
        with pytest.raises(ValueError):
            FieldSpec("event", FieldType.CATEGORICAL_FIXED)

    def test_discrete_requires_max_value(self):
        """NUMERICAL_DISCRETE without ``max_value`` is rejected."""
        with pytest.raises(ValueError):
            FieldSpec("qty", FieldType.NUMERICAL_DISCRETE)
|
|
|
|
class TestDomainSchema:
    """DomainSchema aggregate properties, checked against FINANCE_SCHEMA."""

    def test_finance_token_count(self):
        """Total special tokens: 8 reserved + sign (2) + amount bins (21) + calendar (74)."""
        # NOTE(review): the 8 is assumed to be the schema's reserved tokens
        # ([PAD], [BOS], [EOS], separator, ...) — confirm against DomainSchema.
        per_part = [8, 2, 21, 74]
        assert FINANCE_SCHEMA.special_token_count == sum(per_part)

    def test_finance_fixed_tokens(self):
        """Every finance event contributes a fixed 7 non-text tokens."""
        assert FINANCE_SCHEMA.fixed_tokens_per_event == 7

    def test_has_text_fields(self):
        """FINANCE_SCHEMA contains at least one TEXT field."""
        assert FINANCE_SCHEMA.has_text_fields is True

    def test_text_field_names(self):
        """The only TEXT field in FINANCE_SCHEMA is ``description``."""
        assert FINANCE_SCHEMA.text_field_names == ["description"]

    def test_fittable_fields(self):
        """Only ``amount`` needs fitting (its bucket edges come from data)."""
        assert FINANCE_SCHEMA.fittable_field_names == ["amount"]

    def test_get_field(self):
        """get_field returns the FieldSpec for a known name."""
        amount_spec = FINANCE_SCHEMA.get_field("amount")
        assert amount_spec.field_type == FieldType.NUMERICAL_CONTINUOUS

    def test_get_field_missing(self):
        """get_field returns None for an unknown name instead of raising."""
        assert FINANCE_SCHEMA.get_field("nonexistent") is None

    def test_summary(self):
        """summary() mentions the schema name."""
        text = FINANCE_SCHEMA.summary()
        assert "finance" in text
|
|
|
|
class TestSignTokenizer:
    """SignTokenizer: maps a number to a POS/NEG token; non-negative and missing -> POS."""

    def test_positive(self):
        """Positive values map to the POS token."""
        sign = SignTokenizer("S")
        assert sign(79.99) == "[S_POS]"

    def test_negative(self):
        """Negative values map to the NEG token."""
        sign = SignTokenizer("S")
        assert sign(-50.0) == "[S_NEG]"

    def test_zero(self):
        """Zero is treated as positive."""
        sign = SignTokenizer("S")
        assert sign(0.0) == "[S_POS]"

    def test_none(self):
        """A missing value falls back to POS."""
        sign = SignTokenizer("S")
        assert sign(None) == "[S_POS]"

    def test_nan(self):
        """NaN falls back to POS."""
        sign = SignTokenizer("S")
        assert sign(float("nan")) == "[S_POS]"

    def test_vocab_size(self):
        """Exactly two tokens: POS and NEG."""
        sign = SignTokenizer("S")
        assert sign.vocab_size == 2

    def test_custom_labels(self):
        """Custom labels replace POS/NEG in the emitted tokens."""
        direction = SignTokenizer("D", pos_label="CREDIT", neg_label="DEBIT")
        assert direction(100) == "[D_CREDIT]"
        assert direction(-100) == "[D_DEBIT]"
|
|
|
|
class TestMagnitudeBucketTokenizer:
    """MagnitudeBucketTokenizer: buckets |value| into ``n_bins`` bins learned by fit()."""

    def setup_method(self):
        # Fresh 5-bin tokenizer per test, fitted on a roughly log-spaced sample.
        self.tok = MagnitudeBucketTokenizer("A", n_bins=5)
        self.tok.fit(np.array([1, 2, 5, 10, 20, 50, 100, 200, 500, 1000]))

    def test_low(self):
        """The smallest fitted value lands in the first bucket."""
        assert self.tok(1.0) == "[A_00]"

    def test_high(self):
        """The largest fitted value lands in the last bucket."""
        assert self.tok(1000.0) == "[A_04]"

    def test_negative_abs(self):
        """Bucketing uses the absolute value, so the sign is ignored."""
        assert self.tok(50.0) == self.tok(-50.0)

    def test_none(self):
        """A missing value falls back to the first bucket."""
        assert self.tok(None) == "[A_00]"

    def test_nan(self):
        """NaN falls back to the first bucket."""
        assert self.tok(float("nan")) == "[A_00]"

    def test_vocab(self):
        """The vocabulary has one token per bin."""
        assert self.tok.vocab_size == 5

    def test_not_fitted(self):
        """Calling an unfitted tokenizer raises RuntimeError."""
        with pytest.raises(RuntimeError):
            MagnitudeBucketTokenizer("X")(50.0)

    def test_empty_fit(self):
        """Fitting on an empty array raises ValueError."""
        with pytest.raises(ValueError):
            MagnitudeBucketTokenizer("X").fit(np.array([]))

    def test_nubank_21(self):
        """21 bins on a lognormal sample; in-range and out-of-range values all map into the vocab."""
        # Seeded generator: the fitted bin edges (and therefore this test) must be
        # reproducible. The unseeded global np.random made every run fit different edges.
        rng = np.random.default_rng(0)
        tok = MagnitudeBucketTokenizer("A", n_bins=21)
        tok.fit(rng.lognormal(mean=3, sigma=1, size=10_000))
        assert tok.vocab_size == 21
        for v in [0.01, 1.0, 100.0, 10000.0]:
            assert tok(v) in tok.vocab

    def test_serialization(self):
        """to_dict/from_dict round-trips to a tokenizer that buckets identically."""
        d = self.tok.to_dict()
        tok2 = MagnitudeBucketTokenizer.from_dict(d)
        assert tok2(50.0) == self.tok(50.0)
|
|
|
|
class TestDiscreteNumericalTokenizer:
    """DiscreteNumericalTokenizer: integer values 0..max_value, with an overflow token."""

    def setup_method(self):
        # Every test uses the same 0..10 range.
        self.qty = DiscreteNumericalTokenizer("Q", max_value=10)

    def test_normal(self):
        """An in-range value becomes a two-digit zero-padded token."""
        assert self.qty(3) == "[Q_03]"

    def test_zero(self):
        """Zero maps to the first token."""
        assert self.qty(0) == "[Q_00]"

    def test_max(self):
        """max_value itself is still in range."""
        assert self.qty(10) == "[Q_10]"

    def test_overflow(self):
        """Values above max_value map to the OVER token."""
        assert self.qty(15) == "[Q_OVER]"

    def test_negative(self):
        """Negative values clamp to zero."""
        assert self.qty(-5) == "[Q_00]"

    def test_none(self):
        """A missing value falls back to zero."""
        assert self.qty(None) == "[Q_00]"

    def test_vocab(self):
        """Vocabulary: 11 in-range tokens (0..10) plus OVER."""
        assert self.qty.vocab_size == 12
|
|
|
|
class TestCalendarTokenizer:
    """CalendarTokenizer: one token per configured calendar component of a timestamp."""

    def test_full(self):
        """All four standard components produce four tokens, in the configured order."""
        cal = CalendarTokenizer("T", fields=["month", "dow", "dom", "hour"])
        month_tok, _dow_tok, _dom_tok, hour_tok = out = cal(datetime(2025, 3, 15, 14, 30))
        assert len(out) == 4
        assert month_tok == "[T_MON_03]"
        assert hour_tok == "[T_HOUR_14]"

    def test_string_input(self):
        """ISO-8601 datetime strings are parsed."""
        cal = CalendarTokenizer("T", ["month"])
        assert cal("2025-03-15T14:30:00") == ["[T_MON_03]"]

    def test_date_only(self):
        """Date-only ISO strings are parsed too."""
        cal = CalendarTokenizer("T", ["month", "dow"])
        out = cal("2025-03-15")
        assert out[0] == "[T_MON_03]"

    def test_vocab_standard(self):
        """month + dow + dom + hour = 12 + 7 + 31 + 24 tokens."""
        cal = CalendarTokenizer("T", ["month", "dow", "dom", "hour"])
        assert cal.vocab_size == 74

    def test_subset(self):
        """month + dow = 12 + 7 tokens."""
        cal = CalendarTokenizer("T", ["month", "dow"])
        assert cal.vocab_size == 19

    def test_invalid(self):
        """Unknown component names are rejected at construction."""
        with pytest.raises(ValueError):
            CalendarTokenizer("T", ["invalid"])

    def test_quarter(self):
        """January falls in Q1 and October in Q4."""
        cal = CalendarTokenizer("T", ["quarter"])
        assert cal(datetime(2025, 1, 1)) == ["[T_Q1]"]
        assert cal(datetime(2025, 10, 1)) == ["[T_Q4]"]
|
|
|
|
class TestCategoricalTokenizer:
    """CategoricalTokenizer: index-based tokens for known categories, UNK otherwise."""

    def test_known(self):
        """A known category maps to its zero-padded index (``buy`` is index 1)."""
        cat = CategoricalTokenizer("E", ["view", "buy"])
        assert cat("buy") == "[E_001]"

    def test_unknown(self):
        """A category not seen at construction maps to UNK."""
        cat = CategoricalTokenizer("E", ["view", "buy"])
        assert cat("refund") == "[E_UNK]"

    def test_none(self):
        """A missing value maps to UNK."""
        cat = CategoricalTokenizer("E", ["view"])
        assert cat(None) == "[E_UNK]"

    def test_vocab_unk(self):
        """The vocab holds one token per category plus UNK."""
        cat = CategoricalTokenizer("E", ["a", "b"])
        assert "[E_UNK]" in cat.vocab
        assert cat.vocab_size == 3

    def test_decode(self):
        """decode_token maps an index token back to its category string."""
        cat = CategoricalTokenizer("E", ["view", "buy"])
        assert cat.decode_token("[E_000]") == "view"
|
|
|
|
class TestFactory:
    """create_field_tokenizer dispatches each FieldType to its tokenizer class."""

    def test_sign(self):
        """SIGN -> SignTokenizer."""
        built = create_field_tokenizer(FieldSpec("s", FieldType.SIGN))
        assert isinstance(built, SignTokenizer)

    def test_magnitude(self):
        """NUMERICAL_CONTINUOUS -> MagnitudeBucketTokenizer."""
        built = create_field_tokenizer(FieldSpec("a", FieldType.NUMERICAL_CONTINUOUS))
        assert isinstance(built, MagnitudeBucketTokenizer)

    def test_discrete(self):
        """NUMERICAL_DISCRETE -> DiscreteNumericalTokenizer."""
        spec = FieldSpec("q", FieldType.NUMERICAL_DISCRETE, max_value=10)
        built = create_field_tokenizer(spec)
        assert isinstance(built, DiscreteNumericalTokenizer)

    def test_calendar(self):
        """TEMPORAL -> CalendarTokenizer."""
        built = create_field_tokenizer(FieldSpec("t", FieldType.TEMPORAL))
        assert isinstance(built, CalendarTokenizer)

    def test_categorical(self):
        """CATEGORICAL_FIXED -> CategoricalTokenizer."""
        spec = FieldSpec("c", FieldType.CATEGORICAL_FIXED, categories=["a"])
        built = create_field_tokenizer(spec)
        assert isinstance(built, CategoricalTokenizer)

    def test_text_none(self):
        """TEXT has no field tokenizer, so the factory returns None."""
        built = create_field_tokenizer(FieldSpec("d", FieldType.TEXT))
        assert built is None
|
|
|
|
class TestPredefinedSchemas:
    """Names, field counts and domain-token totals of the bundled schemas."""

    def test_finance(self):
        """FINANCE_SCHEMA is named ``finance`` and has 4 fields."""
        schema = FINANCE_SCHEMA
        assert schema.name == "finance"
        assert len(schema.fields) == 4

    def test_ecommerce(self):
        """ECOMMERCE_SCHEMA is named ``ecommerce`` and has 6 fields."""
        schema = ECOMMERCE_SCHEMA
        assert schema.name == "ecommerce"
        assert len(schema.fields) == 6

    def test_healthcare(self):
        """HEALTHCARE_SCHEMA is named ``healthcare`` and has 6 fields."""
        schema = HEALTHCARE_SCHEMA
        assert schema.name == "healthcare"
        assert len(schema.fields) == 6

    def test_nubank_97(self):
        """Finance field tokens total 97: sign 2 + amount 21 + calendar 74 + text 0."""
        per_field = [spec.token_count for spec in FINANCE_SCHEMA.fields]
        assert sum(per_field) == 97
|
|
|
|
class TestDomainTokenizerBuilder:
    """DomainTokenizerBuilder on FINANCE_SCHEMA: fit, tokenize, build (HF) and encode."""

    @pytest.fixture
    def events(self):
        """Three finance events: two positive amounts and one negative amount."""
        return [
            {"amount_sign": 79.99, "amount": 79.99,
             "timestamp": datetime(2025, 3, 15, 14, 30), "description": "AMAZON"},
            {"amount_sign": -200.0, "amount": -200.0,
             "timestamp": datetime(2025, 3, 16, 9, 15), "description": "SALARY"},
            {"amount_sign": 12.50, "amount": 12.50,
             "timestamp": datetime(2025, 3, 17, 18, 45), "description": "UBER"},
        ]

    @pytest.fixture
    def corpus(self):
        """Small repeated text corpus passed to ``build`` for the BPE vocabulary."""
        return ["AMAZON", "SALARY", "UBER", "GROCERY", "NETFLIX"] * 20

    @pytest.fixture
    def fitted(self, events):
        """A FINANCE_SCHEMA builder already fitted on ``events`` (shared setup)."""
        builder = DomainTokenizerBuilder(FINANCE_SCHEMA)
        builder.fit(events)
        return builder

    def test_create(self):
        """A fresh builder is not fitted."""
        assert not DomainTokenizerBuilder(FINANCE_SCHEMA).is_fitted

    def test_fit(self, events):
        """fit() marks the builder as fitted."""
        # Fit explicitly rather than via the ``fitted`` fixture: fitting is what is under test.
        b = DomainTokenizerBuilder(FINANCE_SCHEMA)
        b.fit(events)
        assert b.is_fitted

    def test_tokenize_event(self, fitted, events):
        """One event yields at least the schema's fixed tokens, starting with the sign token."""
        tokens = fitted.tokenize_event(events[0])
        assert len(tokens) >= FINANCE_SCHEMA.fixed_tokens_per_event
        assert tokens[0].startswith("[AMT_SIGN_")

    def test_tokenize_sequence(self, fitted, events):
        """A sequence is wrapped in BOS/EOS with one separator between consecutive events."""
        tokens = fitted.tokenize_sequence(events)
        assert tokens[0] == "[BOS]"
        assert tokens[-1] == "[EOS]"
        # n events -> n - 1 separators (derived from the fixture, not hard-coded).
        assert tokens.count(FINANCE_SCHEMA.event_separator) == len(events) - 1

    def test_build(self, fitted, corpus):
        """build() returns an HF tokenizer that knows the domain tokens and [PAD]."""
        hf = fitted.build(text_corpus=corpus, bpe_vocab_size=300)
        assert hf.pad_token == "[PAD]"
        assert hf.convert_tokens_to_ids("[AMT_SIGN_POS]") != hf.unk_token_id

    def test_end_to_end(self, fitted, events, corpus):
        """encode_sequence pads to max_length and leaves a non-trivial number of real tokens."""
        hf = fitted.build(text_corpus=corpus, bpe_vocab_size=300)
        enc = fitted.encode_sequence(events, hf, max_length=128)
        assert len(enc["input_ids"]) == 128
        assert sum(1 for m in enc["attention_mask"] if m == 1) > 10

    def test_stats(self, fitted):
        """get_stats() reports the schema name and fitted state."""
        s = fitted.get_stats()
        assert s["schema_name"] == "finance"
        assert s["is_fitted"]

    def test_unfitted_raises(self):
        """build() on an unfitted builder raises RuntimeError."""
        with pytest.raises(RuntimeError):
            DomainTokenizerBuilder(FINANCE_SCHEMA).build()
|
|
|
|
class TestEcommerceBuilder:
    """End-to-end smoke test of the builder on ECOMMERCE_SCHEMA."""

    def test_full(self):
        """Fit, build and encode a two-event view->purchase session."""
        # Both events share product/category metadata and differ only in action, quantity and time.
        shared = {"price": 29.99, "category": "electronics", "product_title": "Mouse"}
        view = {"event_type": "view", "quantity": 1,
                "timestamp": datetime(2025, 3, 15, 10, 0), **shared}
        purchase = {"event_type": "purchase", "quantity": 2,
                    "timestamp": datetime(2025, 3, 15, 10, 10), **shared}
        session = [view, purchase]

        builder = DomainTokenizerBuilder(ECOMMERCE_SCHEMA)
        builder.fit(session)
        hf = builder.build(text_corpus=["Mouse", "Keyboard"] * 20, bpe_vocab_size=200)
        encoded = builder.encode_sequence(session, hf, max_length=256)

        real_tokens = sum(1 for flag in encoded["attention_mask"] if flag == 1)
        assert real_tokens > 10
|
|