guide / tests /ner /test_train.py
Saravanakumar R
openspec add tests for all dl models
b2ef214
Raw
History Blame Contribute Delete
4.82 kB
"""
Unit tests for pure-logic functions in src/ner/train.py.
No model checkpoint required. All tests run in milliseconds.
Run: pytest tests/ner/test_train.py
"""
from src.ner.model import LABEL2ID
from src.ner.train import (
_assign_bio_labels,
_extract_slots,
_fill_template,
_word_tokenize,
)
# ---------------------------------------------------------------------------
# _extract_slots
# ---------------------------------------------------------------------------
def test_extract_slots_single():
    """A template with one placeholder yields exactly that slot name."""
    found = _extract_slots("{ORG} failed to deliver")
    assert found == ["ORG"]
def test_extract_slots_multiple_in_order():
    """Slot names are returned in the order they appear in the template."""
    template = "{ORG} charged {AMOUNT} on {DATE}"
    expected = ["ORG", "AMOUNT", "DATE"]
    assert _extract_slots(template) == expected
def test_extract_slots_no_slots():
    """A template without any placeholder yields an empty list."""
    found = _extract_slots("This is a plain sentence")
    assert found == []
# ---------------------------------------------------------------------------
# _fill_template
# ---------------------------------------------------------------------------
def test_fill_template_span_offsets_are_correct():
    """Each span's [start:end] slice of the filled sentence equals its slot value.

    Also checks that every slot in the template produced a span, so the
    per-span offset loop cannot pass vacuously on an empty span list.
    """
    template = "{ORG} charged {AMOUNT}"
    slot_values = {"ORG": "Flipkart", "AMOUNT": "₹4,299"}
    sentence, spans = _fill_template(template, slot_values)
    # Without this guard an empty or partial span list would skip the
    # offset checks below and the test would still report success.
    found_labels = {span["label"] for span in spans}
    assert found_labels == set(slot_values), (
        f"Expected spans for {sorted(slot_values)!r}, got {sorted(found_labels)!r}"
    )
    for span in spans:
        extracted = sentence[span["start"]:span["end"]]
        assert extracted == slot_values[span["label"]], (
            f"Offset mismatch for {span['label']!r}: got {extracted!r}"
        )
def test_fill_template_spans_non_overlapping_and_ordered():
    """Consecutive spans must not overlap and must be sorted by start offset.

    The span count is asserted first: with fewer than two spans the
    adjacent-pair loop would have nothing to compare and pass vacuously.
    """
    template = "{ORG} charged {AMOUNT} on {DATE}"
    slot_values = {"ORG": "Flipkart", "AMOUNT": "₹4,299", "DATE": "12 March 2024"}
    _, spans = _fill_template(template, slot_values)
    assert len(spans) == len(slot_values), (
        f"Expected {len(slot_values)} spans, got {len(spans)}"
    )
    # Walk adjacent pairs directly instead of indexing with range(len(...)).
    for current, following in zip(spans, spans[1:]):
        assert current["end"] <= following["start"], "Spans overlap"
        assert current["start"] < following["start"], "Spans not in order"
def test_fill_template_span_text_matches_slot_value():
    """Slicing the filled sentence by each span recovers the slot's value."""
    filled, found_spans = _fill_template(
        "{ORG} charged {AMOUNT}",
        {"ORG": "HDFC Bank", "AMOUNT": "₹1,200"},
    )
    # Map each span label to the text its offsets select in the sentence.
    text_by_label = {}
    for entry in found_spans:
        text_by_label[entry["label"]] = filled[entry["start"]:entry["end"]]
    assert text_by_label["ORG"] == "HDFC Bank"
    assert text_by_label["AMOUNT"] == "₹1,200"
# ---------------------------------------------------------------------------
# _word_tokenize
# ---------------------------------------------------------------------------
def test_word_tokenize_offsets_round_trip():
    """Each token's (start, end) offsets slice its own text out of the sentence.

    The token list is asserted non-empty first so the round-trip loop
    cannot pass vacuously if the tokenizer returns nothing.
    """
    sentence = "Flipkart charged ₹4,299"
    tokens = _word_tokenize(sentence)
    assert tokens, "Expected at least one token for a non-empty sentence"
    for word, start, end in tokens:
        assert sentence[start:end] == word, f"Offset mismatch for {word!r}"
def test_word_tokenize_punctuation_preserved():
    """A word with trailing punctuation comes back as one token spanning 0..9."""
    result = _word_tokenize("Flipkart.")
    assert len(result) == 1
    text, begin, finish = result[0]
    assert (text, begin, finish) == ("Flipkart.", 0, 9)
def test_word_tokenize_multiple_words():
    """A four-word sentence yields its four words, in order."""
    collected = []
    for text, _start, _end in _word_tokenize("I filed a complaint"):
        collected.append(text)
    assert collected == ["I", "filed", "a", "complaint"]
# ---------------------------------------------------------------------------
# _assign_bio_labels
# ---------------------------------------------------------------------------
def _make_words_and_spans(sentence, entity_text, label):
    """Tokenize *sentence* and build a one-element span list for *entity_text*.

    The span covers the first occurrence of ``entity_text`` in ``sentence``
    (``str.index`` raises ``ValueError`` if it is absent) and carries
    ``label``. Returns ``(words, spans)``.
    """
    begin = sentence.index(entity_text)
    entity_span = {
        "start": begin,
        "end": begin + len(entity_text),
        "label": label,
    }
    return _word_tokenize(sentence), [entity_span]
def test_assign_bio_labels_output_length():
    """The label sequence has exactly one entry per word token."""
    tokens = _word_tokenize("Flipkart charged ₹4,299 without authorization")
    org_span = {"start": 0, "end": 8, "label": "ORG"}
    assigned = _assign_bio_labels(tokens, [org_span])
    assert len(assigned) == len(tokens)
def test_assign_bio_labels_first_word_gets_b():
    """The first word of an entity span is labelled with the B- tag."""
    tokens, entity_spans = _make_words_and_spans(
        "Flipkart charged ₹4,299", "Flipkart", "ORG"
    )
    first_label = _assign_bio_labels(tokens, entity_spans)[0]
    assert first_label == LABEL2ID["B-ORG"]
def test_assign_bio_labels_multi_word_entity():
    """A four-word entity is labelled B-ORG followed by three I-ORG tags."""
    tokens, entity_spans = _make_words_and_spans(
        "State Bank of India charged a fee", "State Bank of India", "ORG"
    )
    assigned = _assign_bio_labels(tokens, entity_spans)
    assert assigned[0] == LABEL2ID["B-ORG"]
    # Words 1..3 ("Bank", "of", "India") continue the same entity.
    for position in (1, 2, 3):
        assert assigned[position] == LABEL2ID["I-ORG"]
def test_assign_bio_labels_outside_words_are_o():
    """Words that fall outside every entity span receive the O label."""
    tokens, entity_spans = _make_words_and_spans(
        "Flipkart charged me", "Flipkart", "ORG"
    )
    assigned = _assign_bio_labels(tokens, entity_spans)
    # "charged" and "me" come after the entity and lie outside the span.
    for position in (1, 2):
        assert assigned[position] == LABEL2ID["O"]