Spaces:
Sleeping
Sleeping
| """ | |
| Unit tests for pure-logic functions in src/classifier/train.py. | |
| No model checkpoint required. All tests run in milliseconds. | |
| Run: pytest tests/classifier/test_train.py | |
| """ | |
| from src.classifier.model import DOMAIN_LABELS | |
| from src.classifier.train import _build_supplement, _fill, _map_product | |
| # --------------------------------------------------------------------------- | |
| # _map_product | |
| # --------------------------------------------------------------------------- | |
def test_map_product_exact_banking():
    """'Checking or savings account' must map to the 'banking' domain."""
    domain = _map_product("Checking or savings account")
    assert domain == "banking"
def test_map_product_exact_cibil():
    """'Credit card' must map to the 'cibil' domain."""
    domain = _map_product("Credit card")
    assert domain == "cibil"
def test_map_product_keyword_fallback_banking():
    """A longer product string mentioning 'mortgage' must still map to 'banking'."""
    product = "Some new mortgage product"
    assert _map_product(product) == "banking"
def test_map_product_keyword_fallback_cibil():
    """A longer product string mentioning 'credit card' must still map to 'cibil'."""
    product = "My credit card debt collection service"
    assert _map_product(product) == "cibil"
def test_map_product_unknown_returns_none():
    """A product string the mapper does not recognise must yield ``None``."""
    domain = _map_product("Exotic product never seen before")
    assert domain is None
def test_map_product_strips_whitespace():
    """Leading and trailing spaces around the product name must not affect the result."""
    padded_product = " Mortgage "
    assert _map_product(padded_product) == "banking"
| # --------------------------------------------------------------------------- | |
| # _build_supplement | |
| # --------------------------------------------------------------------------- | |
def test_build_supplement_all_six_domains_present():
    """The set of labels in the supplement must equal the set of ids in ``DOMAIN2ID``."""
    # Function-scope import: DOMAIN2ID is not among the module-level imports.
    from src.classifier.model import DOMAIN2ID

    ds = _build_supplement(n_per_class=10, seed=42)
    label_values = set(ds["labels"])
    assert label_values == set(DOMAIN2ID.values())
def test_build_supplement_count_per_class():
    """Every label in the supplement must occur exactly ``n_per_class`` times."""
    from collections import Counter

    n_per_class = 20
    ds = _build_supplement(n_per_class=n_per_class, seed=42)
    counts = Counter(ds["labels"])
    # Comparing the set of counts against {n_per_class} also fails when no
    # labels were produced at all; all(...) over an empty Counter would pass
    # vacuously and hide that failure.
    assert set(counts.values()) == {n_per_class}, f"Uneven counts: {counts}"
def test_build_supplement_deterministic():
    """Two builds with the same ``n_per_class`` and ``seed`` must produce equal text."""
    build_args = {"n_per_class": 10, "seed": 42}
    first = _build_supplement(**build_args)
    second = _build_supplement(**build_args)
    assert first["text"] == second["text"]
def test_build_supplement_no_unfilled_placeholders():
    """No generated text may still contain a literal template placeholder."""
    placeholders = ("{amount}", "{date}", "{ref}", "{days}")
    dataset = _build_supplement(n_per_class=10, seed=42)
    for text in dataset["text"]:
        # First placeholder (in declaration order) still present in this text, if any.
        ph = next((candidate for candidate in placeholders if candidate in text), None)
        assert ph is None, f"Unfilled placeholder {ph!r} in: {text!r}"