# Source: guide/tests/classifier/test_train.py (2.3 kB)
# Author: Saravanakumar R
# Commit b2ef214: "openspec add tests for all dl models"
"""
Unit tests for pure-logic functions in src/classifier/train.py.
No model checkpoint required. All tests run in milliseconds.
Run: pytest tests/classifier/test_train.py
"""
from src.classifier.model import DOMAIN_LABELS
from src.classifier.train import _build_supplement, _fill, _map_product
# ---------------------------------------------------------------------------
# _map_product
# ---------------------------------------------------------------------------
def test_map_product_exact_banking():
    """The exact product name 'Checking or savings account' maps to 'banking'."""
    product = "Checking or savings account"
    assert _map_product(product) == "banking"
def test_map_product_exact_cibil():
    """The exact product name 'Credit card' maps to 'cibil'."""
    product = "Credit card"
    assert _map_product(product) == "cibil"
def test_map_product_keyword_fallback_banking():
    """A non-exact name containing 'mortgage' still resolves to 'banking'."""
    mapped = _map_product("Some new mortgage product")
    assert mapped == "banking"
def test_map_product_keyword_fallback_cibil():
    """A free-form name mentioning credit card / debt collection resolves to 'cibil'."""
    mapped = _map_product("My credit card debt collection service")
    assert mapped == "cibil"
def test_map_product_unknown_returns_none():
    """A product name matching neither an exact entry nor a keyword yields None."""
    mapped = _map_product("Exotic product never seen before")
    assert mapped is None
def test_map_product_strips_whitespace():
    """Surrounding whitespace on the product name does not prevent a match."""
    padded = " Mortgage "
    assert _map_product(padded) == "banking"
# ---------------------------------------------------------------------------
# _build_supplement
# ---------------------------------------------------------------------------
def test_build_supplement_all_six_domains_present():
    """Every domain id in DOMAIN2ID appears among the supplement's labels.

    The label set must equal the set of DOMAIN2ID values exactly: no id may be
    missing and no unexpected id may appear.
    """
    # Local import kept from the original; the unused `import random` was dropped.
    from src.classifier.model import DOMAIN2ID

    ds = _build_supplement(n_per_class=10, seed=42)
    label_values = set(ds["labels"])
    # NOTE(review): the test name says "six"; the domain count itself is not
    # asserted here, only that the label set matches DOMAIN2ID.
    assert label_values == set(DOMAIN2ID.values())
def test_build_supplement_count_per_class():
    """_build_supplement emits exactly n_per_class examples for every label."""
    from collections import Counter

    ds = _build_supplement(n_per_class=20, seed=42)
    counts = Counter(ds["labels"])
    # all() over an empty iterable is True, so without this guard an empty
    # supplement would pass the per-class check below vacuously.
    assert counts, "Supplement produced no labelled examples"
    assert all(c == 20 for c in counts.values()), f"Uneven counts: {counts}"
def test_build_supplement_deterministic():
    """Two calls with the same seed reproduce identical texts and labels."""
    ds1 = _build_supplement(n_per_class=10, seed=42)
    ds2 = _build_supplement(n_per_class=10, seed=42)
    assert ds1["text"] == ds2["text"]
    # Matching texts alone would not catch a nondeterministic label order or
    # text/label pairing, so the labels column is compared as well.
    assert ds1["labels"] == ds2["labels"]
def test_build_supplement_no_unfilled_placeholders():
    """None of the known template placeholders survives into generated text.

    Checks the four placeholders {amount}, {date}, {ref} and {days}.
    """
    ds = _build_supplement(n_per_class=10, seed=42)
    template_keys = ["{amount}", "{date}", "{ref}", "{days}"]
    for sample in ds["text"]:
        for key in template_keys:
            assert key not in sample, f"Unfilled placeholder {key!r} in: {sample!r}"