"""
Unit tests for pure-logic functions in src/ner/train.py.

No model checkpoint required. All tests run in milliseconds.

Run:
    pytest tests/ner/test_train.py
"""

from src.ner.model import LABEL2ID
from src.ner.train import (
    _assign_bio_labels,
    _extract_slots,
    _fill_template,
    _word_tokenize,
)


# ---------------------------------------------------------------------------
# _extract_slots
# ---------------------------------------------------------------------------

def test_extract_slots_single():
    """A template with one placeholder yields that placeholder's name."""
    assert _extract_slots("{ORG} failed to deliver") == ["ORG"]


def test_extract_slots_multiple_in_order():
    """Placeholder names are returned in the order they appear."""
    assert _extract_slots("{ORG} charged {AMOUNT} on {DATE}") == [
        "ORG",
        "AMOUNT",
        "DATE",
    ]


def test_extract_slots_no_slots():
    """A template without placeholders yields an empty list."""
    assert _extract_slots("This is a plain sentence") == []


# ---------------------------------------------------------------------------
# _fill_template
# ---------------------------------------------------------------------------

def test_fill_template_span_offsets_are_correct():
    """Slicing the filled sentence by each span returns the slot's value."""
    template = "{ORG} charged {AMOUNT}"
    slot_values = {"ORG": "Flipkart", "AMOUNT": "₹4,299"}
    sentence, spans = _fill_template(template, slot_values)

    # Guard against a vacuous pass: with no spans the loop below would
    # never execute a single assertion.
    assert len(spans) > 0, "Expected at least one span"

    for span in spans:
        extracted = sentence[span["start"]:span["end"]]
        assert extracted == slot_values[span["label"]], (
            f"Offset mismatch for {span['label']!r}: got {extracted!r}"
        )


def test_fill_template_spans_non_overlapping_and_ordered():
    """Consecutive spans are strictly ordered by start and never overlap."""
    template = "{ORG} charged {AMOUNT} on {DATE}"
    slot_values = {
        "ORG": "Flipkart",
        "AMOUNT": "₹4,299",
        "DATE": "12 March 2024",
    }
    _, spans = _fill_template(template, slot_values)

    # At least two spans are needed for the pairwise checks to run at all.
    assert len(spans) >= 2, "Expected at least two spans to compare"

    for current, following in zip(spans, spans[1:]):
        assert current["end"] <= following["start"], "Spans overlap"
        assert current["start"] < following["start"], "Spans not in order"


def test_fill_template_span_text_matches_slot_value():
    """Each labelled span covers exactly the text substituted for its slot."""
    template = "{ORG} charged {AMOUNT}"
    slot_values = {"ORG": "HDFC Bank", "AMOUNT": "₹1,200"}
    sentence, spans = _fill_template(template, slot_values)
    span_map = {s["label"]: sentence[s["start"]:s["end"]] for s in spans}
    assert span_map["ORG"] == "HDFC Bank"
    assert span_map["AMOUNT"] == "₹1,200"


# ---------------------------------------------------------------------------
# _word_tokenize
# ---------------------------------------------------------------------------

def test_word_tokenize_offsets_round_trip():
    """Each token's (start, end) offsets slice its text out of the input."""
    sentence = "Flipkart charged ₹4,299"
    tokens = _word_tokenize(sentence)

    # Guard against a vacuous pass when no tokens are produced.
    assert len(tokens) > 0, "Expected at least one token"

    for word, start, end in tokens:
        assert sentence[start:end] == word, f"Offset mismatch for {word!r}"


def test_word_tokenize_punctuation_preserved():
    """Trailing punctuation stays attached to its word as a single token."""
    tokens = _word_tokenize("Flipkart.")
    assert len(tokens) == 1
    word, start, end = tokens[0]
    assert word == "Flipkart."
    assert start == 0
    assert end == 9


def test_word_tokenize_multiple_words():
    """A space-separated sentence yields one token per word, in order."""
    tokens = _word_tokenize("I filed a complaint")
    words = [w for w, _, _ in tokens]
    assert words == ["I", "filed", "a", "complaint"]


# ---------------------------------------------------------------------------
# _assign_bio_labels
# ---------------------------------------------------------------------------

def _make_words_and_spans(sentence, entity_text, label):
    """Helper: build word tokens and a single entity span.

    The span covers the first occurrence of ``entity_text`` in ``sentence``
    as character offsets (``end`` exclusive). ``str.index`` raises
    ``ValueError`` if ``entity_text`` does not occur in ``sentence``.

    Returns:
        A ``(words, spans)`` pair: the result of ``_word_tokenize(sentence)``
        and a one-element list holding the span dict with keys ``"start"``,
        ``"end"`` and ``"label"``.
    """
    start = sentence.index(entity_text)
    end = start + len(entity_text)
    words = _word_tokenize(sentence)
    spans = [{"start": start, "end": end, "label": label}]
    return words, spans


def test_assign_bio_labels_output_length():
    """One label is produced for every word token."""
    sentence = "Flipkart charged ₹4,299 without authorization"
    # Same span as the literal offsets 0..8, derived instead of hard-coded.
    words, spans = _make_words_and_spans(sentence, "Flipkart", "ORG")
    labels = _assign_bio_labels(words, spans)
    assert len(labels) == len(words)


def test_assign_bio_labels_first_word_gets_b():
    """The first word of an entity is tagged with the B- label."""
    sentence = "Flipkart charged ₹4,299"
    words, spans = _make_words_and_spans(sentence, "Flipkart", "ORG")
    labels = _assign_bio_labels(words, spans)
    assert labels[0] == LABEL2ID["B-ORG"]


def test_assign_bio_labels_multi_word_entity():
    """A multi-word entity is tagged B- on its first word, I- on the rest."""
    sentence = "State Bank of India charged a fee"
    words, spans = _make_words_and_spans(sentence, "State Bank of India", "ORG")
    labels = _assign_bio_labels(words, spans)
    assert labels[0] == LABEL2ID["B-ORG"]
    assert labels[1] == LABEL2ID["I-ORG"]
    assert labels[2] == LABEL2ID["I-ORG"]
    assert labels[3] == LABEL2ID["I-ORG"]


def test_assign_bio_labels_outside_words_are_o():
    """Words that fall outside every entity span are tagged O."""
    sentence = "Flipkart charged me"
    words, spans = _make_words_and_spans(sentence, "Flipkart", "ORG")
    labels = _assign_bio_labels(words, spans)
    # "charged" and "me" are outside the span
    assert labels[1] == LABEL2ID["O"]
    assert labels[2] == LABEL2ID["O"]