guide / tests /ner /test_model.py
Saravanakumar R
openspec add tests for all dl models
b2ef214
Raw
History Blame Contribute Delete
4.54 kB
"""
Unit tests for EvidenceNER in src/ner/model.py.

Covers _aggregate_spans() (merging per-token B-/I-/O predictions into entity
spans) and the empty/whitespace-only input guard in extract().
Uses EvidenceNER.__new__() to bypass checkpoint loading — no model required.
Run: pytest tests/ner/test_model.py
"""
from src.ner.model import EvidenceNER, LABEL2ID
def _ner():
    """Build a bare EvidenceNER without running __init__.

    Going through __new__ skips checkpoint/tokenizer loading, so only methods
    that do not touch model state can be exercised on the result.
    """
    bare_instance = EvidenceNER.__new__(EvidenceNER)
    return bare_instance
# ---------------------------------------------------------------------------
# Happy path
# ---------------------------------------------------------------------------
def test_aggregate_spans_single_entity():
    """B-AMOUNT followed by I-AMOUNT produces one AMOUNT entity.

    The amount is split the way a sub-word tokenizer would split it: "₹" at
    (0, 1) and "4,299" at (1, 6), followed by "overcharged" at (7, 18).
    The (0, 0) entries stand in for the special tokens at each end.
    """
    ner = _ner()
    text = "₹4,299 overcharged"
    offset_mapping = [(0, 0), (0, 1), (1, 6), (7, 18), (0, 0)]
    pred_ids = [
        LABEL2ID["O"],
        LABEL2ID["B-AMOUNT"],  # ₹
        LABEL2ID["I-AMOUNT"],  # 4,299
        LABEL2ID["O"],         # overcharged
        LABEL2ID["O"],
    ]
    confs = [0.99, 0.91, 0.88, 0.95, 0.99]
    # Guard the fixture itself: offsets must line up with the text, otherwise
    # the span assertions below test nothing meaningful. (The previous fixture
    # tagged the space at (6, 7) as I-AMOUNT and truncated "overcharged".)
    assert text[0:6] == "₹4,299"
    assert text[7:18] == "overcharged"
    entities = ner._aggregate_spans(text, offset_mapping, pred_ids, confs)
    assert len(entities) == 1
    assert entities[0].label == "AMOUNT"
    assert entities[0].start == 0
    # The span ends at the last I-AMOUNT token, with no trailing whitespace.
    assert entities[0].end == 6
def test_aggregate_spans_two_entities():
    """B-AMOUNT…O…B-REF_ID produces two separate entities with exact spans."""
    ner = _ner()
    text = "₹4,299 for OD-123"
    offset_mapping = [
        (0, 0),
        (0, 6),  # ₹4,299
        (7, 10),  # for → O
        (11, 17),  # OD-123
        (0, 0),
    ]
    pred_ids = [
        LABEL2ID["O"],
        LABEL2ID["B-AMOUNT"],
        LABEL2ID["O"],
        LABEL2ID["B-REF_ID"],
        LABEL2ID["O"],
    ]
    confs = [0.99, 0.91, 0.95, 0.82, 0.99]
    entities = ner._aggregate_spans(text, offset_mapping, pred_ids, confs)
    assert len(entities) == 2
    # Checking labels alone would let a span-boundary bug through (e.g. the
    # O token "for" being merged into one of the entities). Pin each span
    # exactly; sort by start so the check does not depend on output order.
    spans = sorted((e.start, e.end, e.label) for e in entities)
    assert spans == [(0, 6, "AMOUNT"), (11, 17, "REF_ID")]
# ---------------------------------------------------------------------------
# Edge cases — broken sequences
# ---------------------------------------------------------------------------
def test_aggregate_spans_orphan_i_tag_dropped():
    """An I-AMOUNT tag with no B-AMOUNT before it yields no entity at all."""
    sample = "some text here"
    offsets = [(0, 0), (0, 4), (5, 9), (10, 14), (0, 0)]
    # The second tag is an orphan continuation: nothing opened an AMOUNT span.
    tag_names = ["O", "I-AMOUNT", "O", "O", "O"]
    tag_ids = [LABEL2ID[name] for name in tag_names]
    scores = [0.99, 0.85, 0.95, 0.95, 0.99]

    result = _ner()._aggregate_spans(sample, offsets, tag_ids, scores)

    assert result == []
def test_aggregate_spans_mismatched_i_closes_span():
    """B-AMOUNT followed by I-DATE should flush AMOUNT and drop I-DATE."""
    ner = _ner()
    text = "₹4,299 today"
    offset_mapping = [(0, 0), (0, 6), (7, 12), (0, 0)]
    pred_ids = [
        LABEL2ID["O"],
        LABEL2ID["B-AMOUNT"],
        LABEL2ID["I-DATE"],  # mismatch — different type
        LABEL2ID["O"],
    ]
    confs = [0.99, 0.91, 0.80, 0.99]
    entities = ner._aggregate_spans(text, offset_mapping, pred_ids, confs)
    assert len(entities) == 1
    assert entities[0].label == "AMOUNT"
    # Count and label alone would still pass if the mismatched I-DATE token
    # were wrongly absorbed into the AMOUNT span. The span must stop at the
    # end of "₹4,299" and not extend over "today".
    assert entities[0].start == 0
    assert entities[0].end == 6
# ---------------------------------------------------------------------------
# Special tokens
# ---------------------------------------------------------------------------
def test_aggregate_spans_special_tokens_skipped():
    """Tokens with a (0, 0) offset (CLS/SEP) never contribute to an entity."""
    # Each row is (offset, tag, confidence). The outer rows are special tokens
    # carrying deliberately misleading ORG tags that must be ignored.
    rows = [
        ((0, 0), "B-ORG", 0.99),  # CLS — should be skipped
        ((0, 8), "B-ORG", 0.91),  # real token
        ((0, 0), "I-ORG", 0.99),  # SEP — should be skipped
    ]
    offsets = [offset for offset, _, _ in rows]
    tag_ids = [LABEL2ID[tag] for _, tag, _ in rows]
    scores = [score for _, _, score in rows]

    found = _ner()._aggregate_spans("Flipkart", offsets, tag_ids, scores)

    assert len(found) == 1
    only = found[0]
    assert only.start == 0
    assert only.end == 8
# ---------------------------------------------------------------------------
# Empty input guard
# ---------------------------------------------------------------------------
def test_extract_empty_string_returns_empty_list():
    """An empty string hits extract()'s guard before any tokenizer access."""
    assert _ner().extract("") == []
def test_extract_whitespace_only_returns_empty_list():
    """Whitespace-only text is rejected by extract()'s guard as well."""
    assert _ner().extract(" ") == []