""" Unit tests for _aggregate_spans() in src/ner/model.py. Uses EvidenceNER.__new__() to bypass checkpoint loading — no model required. Run: pytest tests/ner/test_model.py """ from src.ner.model import EvidenceNER, LABEL2ID def _ner(): """Return an EvidenceNER instance with no checkpoint loaded.""" return EvidenceNER.__new__(EvidenceNER) # --------------------------------------------------------------------------- # Happy path # --------------------------------------------------------------------------- def test_aggregate_spans_single_entity(): """B-AMOUNT followed by I-AMOUNT produces one AMOUNT entity.""" ner = _ner() text = "₹4,299 overcharged" offset_mapping = [(0, 0), (0, 6), (6, 7), (7, 14), (0, 0)] pred_ids = [ LABEL2ID["O"], LABEL2ID["B-AMOUNT"], LABEL2ID["I-AMOUNT"], LABEL2ID["O"], LABEL2ID["O"], ] confs = [0.99, 0.91, 0.88, 0.95, 0.99] entities = ner._aggregate_spans(text, offset_mapping, pred_ids, confs) assert len(entities) == 1 assert entities[0].label == "AMOUNT" assert entities[0].start == 0 assert entities[0].end == 7 def test_aggregate_spans_two_entities(): """B-AMOUNT…O…B-REF_ID produces two separate entities.""" ner = _ner() text = "₹4,299 for OD-123" offset_mapping = [ (0, 0), (0, 6), # ₹4,299 (7, 10), # for → O (11, 17), # OD-123 (0, 0), ] pred_ids = [ LABEL2ID["O"], LABEL2ID["B-AMOUNT"], LABEL2ID["O"], LABEL2ID["B-REF_ID"], LABEL2ID["O"], ] confs = [0.99, 0.91, 0.95, 0.82, 0.99] entities = ner._aggregate_spans(text, offset_mapping, pred_ids, confs) labels = [e.label for e in entities] assert "AMOUNT" in labels assert "REF_ID" in labels assert len(entities) == 2 # --------------------------------------------------------------------------- # Edge cases — broken sequences # --------------------------------------------------------------------------- def test_aggregate_spans_orphan_i_tag_dropped(): """I-AMOUNT with no preceding B-AMOUNT must not produce an entity.""" ner = _ner() text = "some text here" offset_mapping = [(0, 0), (0, 4), (5, 9), (10, 14), (0, 0)] pred_ids = [ LABEL2ID["O"], LABEL2ID["I-AMOUNT"], # orphan — no preceding B- LABEL2ID["O"], LABEL2ID["O"], LABEL2ID["O"], ] confs = [0.99, 0.85, 0.95, 0.95, 0.99] entities = ner._aggregate_spans(text, offset_mapping, pred_ids, confs) assert entities == [] def test_aggregate_spans_mismatched_i_closes_span(): """B-AMOUNT followed by I-DATE should flush AMOUNT and drop I-DATE.""" ner = _ner() text = "₹4,299 today" offset_mapping = [(0, 0), (0, 6), (7, 12), (0, 0)] pred_ids = [ LABEL2ID["O"], LABEL2ID["B-AMOUNT"], LABEL2ID["I-DATE"], # mismatch — different type LABEL2ID["O"], ] confs = [0.99, 0.91, 0.80, 0.99] entities = ner._aggregate_spans(text, offset_mapping, pred_ids, confs) assert len(entities) == 1 assert entities[0].label == "AMOUNT" # --------------------------------------------------------------------------- # Special tokens # --------------------------------------------------------------------------- def test_aggregate_spans_special_tokens_skipped(): """(0,0) offset entries for CLS/SEP must not contribute to any entity.""" ner = _ner() text = "Flipkart" # CLS at 0, real token, SEP at end offset_mapping = [(0, 0), (0, 8), (0, 0)] pred_ids = [ LABEL2ID["B-ORG"], # CLS — should be skipped LABEL2ID["B-ORG"], # real token LABEL2ID["I-ORG"], # SEP — should be skipped ] confs = [0.99, 0.91, 0.99] entities = ner._aggregate_spans(text, offset_mapping, pred_ids, confs) assert len(entities) == 1 assert entities[0].start == 0 assert entities[0].end == 8 # --------------------------------------------------------------------------- # Empty input guard # --------------------------------------------------------------------------- def test_extract_empty_string_returns_empty_list(): """EvidenceNER.extract() guard fires before tokenizer — safe via __new__.""" ner = _ner() assert ner.extract("") == [] def test_extract_whitespace_only_returns_empty_list(): """EvidenceNER.extract() guard fires for whitespace-only input.""" ner = _ner() assert ner.extract(" ") == []