""" Tests for DocMind Attribution Parser. """ import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from pipeline.attribution import ( parse_attributed_response, strip_unattributed, ) class TestParseAttributedResponse: """Test parsing of LLM responses with [SOURCE: chunk_id] tags.""" def test_single_sentence_with_source(self): response = "The policy covers employees. [SOURCE: doc1_c0001]" result = parse_attributed_response(response) assert len(result) == 1 assert result[0].chunk_id == "doc1_c0001" assert "policy" in result[0].text def test_multiple_sentences(self): response = ( "The policy covers employees. [SOURCE: doc1_c0001] " "Benefits include health insurance. [SOURCE: doc1_c0002]" ) result = parse_attributed_response(response) assert len(result) == 2 assert result[0].chunk_id == "doc1_c0001" assert result[1].chunk_id == "doc1_c0002" def test_insufficient_context(self): response = "INSUFFICIENT_CONTEXT" result = parse_attributed_response(response) assert len(result) == 0 def test_empty_response(self): result = parse_attributed_response("") assert len(result) == 0 def test_no_source_tags(self): response = "This has no source citations at all." result = parse_attributed_response(response) # Should parse as sentence with no chunk_id assert len(result) >= 1 assert result[0].chunk_id is None class TestStripUnattributed: """Test filtering of invalid citations.""" def test_keeps_valid_citations(self): from pipeline.attribution import AttributedSentence sentences = [ AttributedSentence(text="Valid claim.", chunk_id="c001", raw_text=""), AttributedSentence(text="Another valid.", chunk_id="c002", raw_text=""), ] valid_ids = {"c001", "c002", "c003"} result = strip_unattributed(sentences, valid_ids) assert len(result) == 2 def test_removes_invalid_citations(self): from pipeline.attribution import AttributedSentence sentences = [ AttributedSentence(text="Valid.", chunk_id="c001", raw_text=""), AttributedSentence(text="Invalid.", chunk_id="c999", raw_text=""), ] valid_ids = {"c001", "c002"} result = strip_unattributed(sentences, valid_ids) assert len(result) == 1 assert result[0].chunk_id == "c001" def test_removes_uncited_sentences(self): from pipeline.attribution import AttributedSentence sentences = [ AttributedSentence(text="No citation.", chunk_id=None, raw_text=""), AttributedSentence(text="Has citation.", chunk_id="c001", raw_text=""), ] valid_ids = {"c001"} result = strip_unattributed(sentences, valid_ids) assert len(result) == 1 def test_all_invalid(self): from pipeline.attribution import AttributedSentence sentences = [ AttributedSentence(text="Bad.", chunk_id="c999", raw_text=""), AttributedSentence(text="Also bad.", chunk_id=None, raw_text=""), ] valid_ids = {"c001"} result = strip_unattributed(sentences, valid_ids) assert len(result) == 0 def test_empty_input(self): result = strip_unattributed([], {"c001"}) assert len(result) == 0 # ── Run tests ─────────────────────────────────────────────────────── if __name__ == "__main__": for cls in [TestParseAttributedResponse, TestStripUnattributed]: print(f"\n{cls.__name__}:") test = cls() passed = failed = 0 for method_name in sorted(dir(test)): if not method_name.startswith("test_"): continue try: getattr(test, method_name)() print(f" [PASS] {method_name}") passed += 1 except AssertionError as e: print(f" [FAIL] {method_name}: {e}") failed += 1 except Exception as e: print(f" [ERROR] {method_name}: {type(e).__name__}: {e}") failed += 1 print(f" {passed} passed, {failed} failed")