"""Tests for the structured citations module.

These tests exercise the parse → manipulate → render flow without going
through any LLM call. They cover sequential numbering, phantom citations,
escaped brackets, grouped/range syntax, references parsing, and edge cases
like citations next to punctuation.
"""

import re
import unittest

from citations import (
    CitationDocument,
    Reference,
    apply_finalization,
    body_and_references_match,
    finalize_with_canonical,
    has_escaped_brackets,
    has_grouped_or_range_syntax,
    has_phantom_citations,
    is_sequential_in_first_appearance_order,
    parse_document,
)
from tests.helpers import make_source


def _body_numbers(text: str) -> list[int]:
    """Return every ``[N]`` number found before the first "### References".

    Numbers are returned in order of appearance, duplicates included.
    """
    body = text.split("### References")[0]
    return [int(n) for n in re.findall(r"\[(\d+)\]", body)]


def _ref_numbers(text: str) -> list[int]:
    """Return every ``[N]`` number found after the first "### References".

    Returns an empty list when the text has no "### References" marker.
    """
    if "### References" not in text:
        return []
    refs = text.split("### References", 1)[1]
    return [int(n) for n in re.findall(r"\[(\d+)\]", refs)]


def _unique_in_order(nums):
    """Return the items of *nums* with duplicates removed, first-seen order kept."""
    # dict preserves insertion order, so its keys are the ordered unique items.
    return list(dict.fromkeys(nums))


class ParseDocumentTests(unittest.TestCase):
    """parse_document(text) → CitationDocument"""

    def test_empty_input_returns_empty_document(self):
        doc = parse_document("")
        self.assertEqual(doc.body, "")
        self.assertEqual(doc.references, [])

    def test_none_input_returns_empty_document(self):
        doc = parse_document(None)  # type: ignore[arg-type]
        self.assertEqual(doc.body, "")
        self.assertEqual(doc.references, [])

    def test_body_only_no_references(self):
        doc = parse_document("Plain body with [1] citation.")
        self.assertEqual(doc.body, "Plain body with [1] citation.")
        self.assertEqual(doc.references, [])

    def test_body_and_references_section_split(self):
        text = "Body [1].\n\n### References\n\n[1] Source A"
        doc = parse_document(text)
        self.assertEqual(doc.body, "Body [1].")
        self.assertEqual(len(doc.references), 1)
        self.assertEqual(doc.references[0].number, 1)
        self.assertEqual(doc.references[0].text, "Source A")

    def test_plain_references_heading_no_markdown(self):
        text = "Body [1].\n\nReferences\n\n[1] Source A"
        doc = parse_document(text)
        self.assertIn("Body [1]", doc.body)
        self.assertEqual(len(doc.references), 1)

    def test_references_heading_case_insensitive(self):
        text = "Body [1].\n\n## REFERENCES\n\n[1] Source A"
        doc = parse_document(text)
        self.assertEqual(len(doc.references), 1)

    def test_references_with_trailing_colon(self):
        text = "Body [1].\n\n### References:\n\n[1] Source A"
        doc = parse_document(text)
        self.assertEqual(len(doc.references), 1)

    def test_escaped_brackets_are_unescaped(self):
        text = "Body \\[1\\]\\[2\\].\n\n### References\n\n\\[1\\] Source A"
        doc = parse_document(text)
        self.assertEqual(doc.body, "Body [1][2].")
        self.assertEqual(doc.references[0].number, 1)

    def test_multiple_reference_entries_parsed(self):
        text = (
            "Body [1] [2] [3].\n\n"
            "### References\n\n"
            "[1] First source\n\n"
            "[2] Second source\n\n"
            "[3] Third source"
        )
        doc = parse_document(text)
        self.assertEqual(len(doc.references), 3)
        self.assertEqual(doc.references[2].text, "Third source")

    def test_reference_lines_not_double_spaced(self):
        text = "Body [1].\n\n### References\n[1] First source\n[2] Stale"
        doc = parse_document(text)
        # Without blank-line separator we still get at least one
        self.assertGreaterEqual(len(doc.references), 1)


class CitedNumbersTests(unittest.TestCase):
    """CitationDocument.cited_numbers and first_appearance_order"""

    def test_no_citations_returns_empty(self):
        doc = parse_document("Plain body, no markers.")
        self.assertEqual(doc.cited_numbers(), [])
        self.assertEqual(doc.first_appearance_order(), [])

    def test_single_citation(self):
        doc = parse_document("Body with [3] citation.")
        self.assertEqual(doc.cited_numbers(), [3])

    def test_grouped_comma_citation(self):
        doc = parse_document("Body with [1, 2, 3] citation.")
        self.assertEqual(doc.cited_numbers(), [1, 2, 3])

    def test_grouped_semicolon_citation(self):
        doc = parse_document("Body with [1; 2] citation.")
        self.assertEqual(doc.cited_numbers(), [1, 2])

    def test_range_citation(self):
        doc = parse_document("Body with [2-5] citation.")
        self.assertEqual(doc.cited_numbers(), [2, 3, 4, 5])

    def test_adjacent_blocks(self):
        doc = parse_document("Body with [1][2][3] citations.")
        self.assertEqual(doc.cited_numbers(), [1, 2, 3])

    def test_first_appearance_order_dedupes(self):
        doc = parse_document("First [3]. Second [1]. Third [3]. Fourth [2].")
        self.assertEqual(doc.first_appearance_order(), [3, 1, 2])

    def test_citations_with_punctuation_around(self):
        doc = parse_document("Sentence ends here[1]. Another sentence,[2] continues.")
        self.assertEqual(doc.cited_numbers(), [1, 2])


class NormalizeTests(unittest.TestCase):
    """CitationDocument.normalize(source_count)"""

    def test_normalize_drops_phantom_citations(self):
        doc = parse_document("Body [1][2][5][6].")
        result = doc.normalize(source_count=2)
        self.assertEqual(result.cited_numbers(), [1, 2])

    def test_normalize_renumbers_to_sequential(self):
        doc = parse_document("First [3]. Second [1]. Third [2].")
        result = doc.normalize(source_count=3)
        # First-cited gets [1], next new gets [2], etc.
        self.assertEqual(result.first_appearance_order(), [1, 2, 3])
        self.assertIn("First [1]", result.body)
        self.assertIn("Second [2]", result.body)
        self.assertIn("Third [3]", result.body)

    def test_normalize_handles_groups(self):
        doc = parse_document("Body [3, 1, 5].")
        result = doc.normalize(source_count=3)
        # 5 is dropped, 3 and 1 become 1 and 2 in first-appearance order.
        self.assertIn("[1][2]", result.body)
        self.assertNotIn("[5]", result.body)

    def test_normalize_handles_range(self):
        doc = parse_document("Body [1-3].")
        result = doc.normalize(source_count=3)
        self.assertEqual(result.cited_numbers(), [1, 2, 3])

    def test_normalize_zero_sources_strips_all(self):
        doc = parse_document("Body [1][2][3].")
        result = doc.normalize(source_count=0)
        self.assertEqual(result.cited_numbers(), [])
        self.assertNotIn("[", result.body)

    def test_normalize_only_phantom_strips_all(self):
        doc = parse_document("Body [99][100].")
        result = doc.normalize(source_count=3)
        self.assertEqual(result.cited_numbers(), [])

    def test_normalize_negative_source_count(self):
        doc = parse_document("Body [1][2].")
        result = doc.normalize(source_count=-5)
        self.assertEqual(result.cited_numbers(), [])

    def test_normalize_cleans_punctuation_after_removal(self):
        doc = parse_document("First sentence [99]. Second sentence.")
        result = doc.normalize(source_count=3)
        # The orphaned period should still be there, no double space.
        # The needle is TWO spaces: a single-space needle would match any
        # ordinary word gap and make this assertion impossible to satisfy.
        self.assertNotIn("  ", result.body)
        self.assertIn(".", result.body)

    def test_normalize_dedupes_within_block(self):
        doc = parse_document("Body [1, 1, 1].")
        result = doc.normalize(source_count=3)
        self.assertEqual(result.cited_numbers(), [1])


class ApplyFinalizationTests(unittest.TestCase):
    """End-to-end apply_finalization(text, sources) → final answer string"""

    def test_basic_sequential_output(self):
        text = (
            "First fact [1].\n\n"
            "Second fact [2].\n\n"
            "### References\n\n"
            "[1] stale\n\n"
            "[2] stale"
        )
        sources = [
            make_source("Source A", 2021),
            make_source("Source B", 2022),
        ]
        result = apply_finalization(text, sources)
        body_nums = _body_numbers(result)
        ref_nums = _ref_numbers(result)
        self.assertEqual(body_nums, [1, 2])
        self.assertEqual(ref_nums, [1, 2])
        self.assertIn("Source A", result)
        self.assertIn("Source B", result)
        self.assertNotIn("stale", result)

    def test_phantom_citations_dropped(self):
        text = (
            "Time crystals are exotic [1][2].\n\n"
            "Wilczek proposed them [3].\n\n"
            "Lab observations [4][6].\n\n"  # 6 is phantom
            "MBL stabilises them [5][6].\n\n"  # 5 and 6 are phantom
            "### References\n\n[1] stale\n\n[2] stale\n\n[3] stale\n\n[4] stale"
        )
        sources = [
            make_source("Wilczek 2012", 2012),
            make_source("Khemani 2016", 2016),
            make_source("Else 2016", 2016),
            make_source("Zhang 2017", 2017),
        ]
        result = apply_finalization(text, sources)
        body_nums = sorted(set(_body_numbers(result)))
        ref_nums = sorted(set(_ref_numbers(result)))
        self.assertEqual(body_nums, ref_nums)
        self.assertLessEqual(max(body_nums, default=0), 4)
        body = result.split("### References")[0]
        self.assertNotIn("[5]", body)
        self.assertNotIn("[6]", body)

    def test_escaped_brackets_round_trip(self):
        text = (
            "Body \\[1\\]\\[2\\].\n\n"
            "Second paragraph cites \\[3\\] more.\n\n"
            "### References\n\n"
            "\\[1\\] stale\n\n\\[2\\] stale\n\n\\[3\\] stale"
        )
        sources = [
            make_source("Source A", 2020),
            make_source("Source B", 2021),
            make_source("Source C", 2022),
        ]
        result = apply_finalization(text, sources)
        self.assertNotIn("\\[", result)
        self.assertNotIn("\\]", result)
        self.assertEqual(_body_numbers(result), [1, 2, 3])
        self.assertEqual(_ref_numbers(result), [1, 2, 3])

    def test_first_appearance_order_is_canonical(self):
        text = (
            "Citing source 3 first [3].\n\n"
            "Then source 1 [1].\n\n"
            "Then source 2 [2]."
        )
        sources = [
            make_source("First in list", 2021),
            make_source("Second in list", 2022),
            make_source("Third in list", 2023),
        ]
        result = apply_finalization(text, sources)
        parts = result.split("### References")
        body = parts[0]
        refs = parts[1]
        self.assertIn("Citing source 3 first [1]", body)
        self.assertIn("Then source 1 [2]", body)
        self.assertIn("Then source 2 [3]", body)
        # Reference [1] must point to the source originally numbered [3] (Third in list)
        self.assertIn("[1]", refs)
        self.assertIn("Third in list", refs.split("[2]")[0])

    def test_grouped_citation_normalised_to_adjacent_blocks(self):
        text = "Body has grouped citation [1, 2, 3]."
        sources = [
            make_source("A", 2020),
            make_source("B", 2021),
            make_source("C", 2022),
        ]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        self.assertIn("[1][2][3]", body)
        self.assertNotIn("[1, 2, 3]", body)

    def test_range_citation_expanded(self):
        text = "Body has range citation [1-3]."
        sources = [
            make_source("A", 2020),
            make_source("B", 2021),
            make_source("C", 2022),
        ]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        self.assertIn("[1][2][3]", body)
        self.assertNotIn("[1-3]", body)

    def test_no_sources_strips_all_citations(self):
        text = "Body [1] with [2] citations.\n\n### References\n\n[1] stale"
        result = apply_finalization(text, sources=[])
        self.assertNotIn("[1]", result)
        self.assertNotIn("[2]", result)

    def test_no_citations_in_text_returns_body(self):
        text = "Plain body with no citations at all."
        result = apply_finalization(text, sources=[make_source("A", 2020)])
        self.assertEqual(result.strip(), "Plain body with no citations at all.")

    def test_only_phantom_citations_strips_all(self):
        text = "Body [99] with [100] citations."
        result = apply_finalization(text, sources=[make_source("A", 2020)])
        self.assertNotIn("[99]", result)
        self.assertNotIn("[100]", result)

    def test_duplicate_citations_in_same_paragraph(self):
        text = "First [1] then [1] again [1]."
        result = apply_finalization(text, [make_source("A", 2020)])
        body = result.split("### References")[0]
        self.assertEqual(body.count("[1]"), 3)
        self.assertEqual(_ref_numbers(result), [1])

    def test_multiple_substantive_paragraphs_get_backfilled(self):
        text = (
            "Opening paragraph cites the source explicitly [1].\n\n"
            "This middle paragraph has enough words to count as substantive "
            "but lacks a citation entirely.\n\n"
            "Closing paragraph also has support [2]."
        )
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        # Middle paragraph should now have a citation borrowed from a neighbour.
        body = result.split("### References")[0]
        self.assertRegex(
            body,
            r"This middle paragraph has enough words to count as substantive "
            r"but lacks a citation entirely\.\s*\[\d+\]",
        )

    def test_short_paragraph_not_backfilled(self):
        text = (
            "Cited [1].\n\n"
            "Tiny.\n\n"
            "Another cited paragraph [2]."
        )
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        # 'Tiny.' is too short to be substantive, no citation added.
        self.assertNotRegex(result, r"Tiny\.\s*\[\d+\]")

    def test_headings_are_not_backfilled(self):
        text = (
            "Body paragraph cites a source [1].\n\n"
            "## A heading should not get a citation\n\n"
            "Closing paragraph has another [2]."
        )
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        self.assertNotRegex(result, r"## A heading should not get a citation\s*\[\d+\]")

    def test_list_items_are_not_backfilled(self):
        text = (
            "Cited paragraph here [1].\n\n"
            "- list item should not be backfilled even if long enough\n\n"
            "Another cited paragraph [2]."
        )
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        self.assertNotRegex(result, r"list item should not be backfilled[^\n]*\[\d+\]")

    def test_dedupes_within_grouped_block(self):
        text = "Body [1, 1, 2, 2]."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        self.assertIn("[1][2]", body)

    def test_phantom_in_groups_filtered(self):
        text = "Body has [2, 99, 1, 100, 3]."
        sources = [
            make_source("A", 2020),
            make_source("B", 2021),
            make_source("C", 2022),
        ]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        # 99 and 100 should be dropped
        self.assertNotIn("[99]", body)
        self.assertNotIn("[100]", body)
        # The valid ones (2, 1, 3 in that order) become [1][2][3]
        self.assertIn("[1][2][3]", body)


class DiagnosticsTests(unittest.TestCase):
    """has_phantom / has_grouped / has_escaped / sequential / matching"""

    def test_has_phantom_detects_out_of_range(self):
        self.assertTrue(has_phantom_citations("Body [3] [99]", source_count=5))
        self.assertFalse(has_phantom_citations("Body [3] [5]", source_count=5))

    def test_has_phantom_detects_zero_or_negative(self):
        self.assertTrue(has_phantom_citations("Body [0]", source_count=5))

    def test_has_grouped_detects_comma_syntax(self):
        self.assertTrue(has_grouped_or_range_syntax("Body [1, 2]"))
        self.assertFalse(has_grouped_or_range_syntax("Body [1][2]"))

    def test_has_grouped_detects_range_syntax(self):
        self.assertTrue(has_grouped_or_range_syntax("Body [1-3]"))

    def test_has_grouped_detects_semicolon_syntax(self):
        self.assertTrue(has_grouped_or_range_syntax("Body [1; 2]"))

    def test_has_escaped_brackets(self):
        self.assertTrue(has_escaped_brackets("Body \\[1\\]"))
        self.assertFalse(has_escaped_brackets("Body [1]"))

    def test_sequential_in_first_appearance_order(self):
        self.assertTrue(is_sequential_in_first_appearance_order("First [1] then [2] then [3]"))
        self.assertFalse(is_sequential_in_first_appearance_order("First [2] then [1] then [3]"))
        self.assertTrue(is_sequential_in_first_appearance_order("Body [1] [1] [2] [2]"))  # dupes ok

    def test_body_and_references_match(self):
        text = "Body [1] [2].\n\n### References\n\n[1] A\n\n[2] B"
        self.assertTrue(body_and_references_match(text))

    def test_body_and_references_mismatch(self):
        text = "Body [1] [2].\n\n### References\n\n[1] A"
        self.assertFalse(body_and_references_match(text))


class FirstAppearanceOrderWithinTests(unittest.TestCase):
    """first_appearance_order_within(N) — order limited to valid range"""

    def test_filters_out_of_range(self):
        doc = parse_document("Body [3] [99] [1] [2]")
        self.assertEqual(doc.first_appearance_order_within(3), [3, 1, 2])

    def test_dedupes(self):
        doc = parse_document("Body [1] [1] [2] [2] [1]")
        self.assertEqual(doc.first_appearance_order_within(5), [1, 2])

    def test_empty(self):
        doc = parse_document("Body with nothing")
        self.assertEqual(doc.first_appearance_order_within(5), [])


class RealWorldLLMOutputTests(unittest.TestCase):
    """Stress tests with messy real-world LLM output patterns."""

    def test_mixed_escaped_and_unescaped(self):
        text = (
            "First paragraph cites both \\[1\\] and [2].\n\n"
            "Second paragraph mixes [3]\\[4\\]."
        )
        sources = [make_source(f"Source {c}", 2020) for c in "ABCD"]
        result = apply_finalization(text, sources)
        self.assertNotIn("\\[", result)
        self.assertEqual(_body_numbers(result), [1, 2, 3, 4])

    def test_citations_in_bulleted_lists(self):
        text = (
            "Treatment options include:\n"
            "- First-line therapy [1]\n"
            "- Second-line therapy [2]\n"
            "- Specialist option [3]\n\n"
            "Closing paragraph that summarises [1][2]."
        )
        sources = [make_source(f"Source {c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        self.assertEqual(set(_body_numbers(result)), {1, 2, 3})

    def test_citations_in_numbered_lists(self):
        text = (
            "1. First step [1]\n"
            "2. Second step [2]\n"
            "3. Third step [3]\n"
        )
        sources = [make_source(f"Source {c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        self.assertEqual(set(_body_numbers(result)), {1, 2, 3})

    def test_citations_separated_by_pre_period_space(self):
        # Some LLMs put a space before the period: "claim [1] ."
        text = "Some claim [1] . Another [2] !"
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        self.assertEqual(_body_numbers(result), [1, 2])

    def test_citations_at_end_of_paragraph(self):
        text = "First paragraph here.\n\nSecond paragraph ends with citation [1]."
        sources = [make_source("A", 2020)]
        result = apply_finalization(text, sources)
        # Only checks that [1] survives somewhere in the body.
        # NOTE(review): an earlier comment said the first paragraph should be
        # backfilled, but this assertion does not verify that — confirm intent.
        body = result.split("### References")[0]
        self.assertIn("[1]", body)

    def test_consecutive_citations_with_no_space(self):
        text = "Body[1][2][3]ends suddenly."
        sources = [make_source(f"Source {c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        self.assertEqual(_body_numbers(result), [1, 2, 3])

    def test_citation_inside_inline_code(self):
        # `[1]` inside backticks — currently treated as a citation;
        # this test documents the behaviour, not necessarily desired.
        text = "Use the function `parse([1])` and cite [1]."
        sources = [make_source("A", 2020)]
        result = apply_finalization(text, sources)
        # Both [1] tokens should resolve to source 1
        body = result.split("### References")[0]
        self.assertIn("[1]", body)

    def test_long_paragraph_with_multiple_citations(self):
        text = (
            "Time crystals are a phase of matter that breaks discrete time "
            "translation symmetry [1]. Wilczek originally proposed continuous "
            "time crystals [2], though no-go theorems ruled them out [3][4]. "
            "Discrete versions emerged from periodically driven systems [5]. "
            "Many-body localisation stabilises them [6]."
        )
        sources = [make_source(f"Source {c}", 2017 + i) for i, c in enumerate("ABCDEF")]
        result = apply_finalization(text, sources)
        self.assertEqual(set(_body_numbers(result)), {1, 2, 3, 4, 5, 6})
        self.assertEqual(set(_ref_numbers(result)), {1, 2, 3, 4, 5, 6})

    def test_completely_phantom_text_falls_back_gracefully(self):
        # Body has only out-of-range citations and the LLM also generated stale refs
        text = (
            "Body cites phantom sources [99] and [100].\n\n"
            "### References\n\n[99] Bogus\n\n[100] Bogus"
        )
        sources = [make_source("Real source", 2020)]
        result = apply_finalization(text, sources)
        self.assertNotIn("[99]", result)
        self.assertNotIn("[100]", result)
        self.assertNotIn("Bogus", result)

    def test_semicolon_grouped_citation(self):
        text = "Body [1; 3; 2]."
        sources = [make_source(f"S{c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        parts = result.split("### References")
        body = parts[0]
        # Order in source: 1, 3, 2 → first appearance is [1] (=A), [2] (=C), [3] (=B)
        self.assertIn("[1][2][3]", body)
        # Reference [1] = original source 1 (Source A)
        self.assertIn("SA", parts[1])

    def test_finalized_output_is_self_consistent(self):
        """After finalization, body and references should reference the same
        set of sources, in the same order, with first body citation = [1]."""
        text = "Body [3] then [1] then [2]."
        sources = [make_source("A", 2020), make_source("B", 2021), make_source("C", 2022)]
        result = apply_finalization(text, sources)
        body_nums = _body_numbers(result)
        ref_nums = _ref_numbers(result)
        # Sets must match
        self.assertEqual(set(body_nums), set(ref_nums))
        # Body must start at [1] in first-appearance order
        self.assertEqual(
            _unique_in_order(body_nums),
            list(range(1, len(set(body_nums)) + 1)),
        )
        # References must be 1..N sequential
        self.assertEqual(ref_nums, sorted(set(ref_nums)))

    def test_finalize_with_canonical_is_truly_idempotent(self):
        """finalize_with_canonical returns (text, sources) that round-trip cleanly."""
        text = "Body [3] then [1] then [2]."
        sources = [make_source("A", 2020), make_source("B", 2021), make_source("C", 2022)]
        once_text, once_sources = finalize_with_canonical(text, sources)
        twice_text, twice_sources = finalize_with_canonical(once_text, once_sources)
        self.assertEqual(once_text, twice_text)
        self.assertEqual(once_sources, twice_sources)
        # And the canonical sources should be in body's first-appearance order
        # Original [3]=C is first cited, so canonical[0] = C
        self.assertEqual(once_sources[0]["title"], "C")
        self.assertEqual(once_sources[1]["title"], "A")
        self.assertEqual(once_sources[2]["title"], "B")

    def test_idempotent_with_phantom_input(self):
        text = "Body [1][99][2][100]"
        sources = [make_source("A", 2020), make_source("B", 2021)]
        once = apply_finalization(text, sources)
        twice = apply_finalization(once, sources)
        self.assertEqual(once, twice)

    def test_blank_lines_in_references_preserved_for_parsing(self):
        text = (
            "Body [1] [2].\n\n"
            "### References\n\n"
            "[1] First entry\n"
            "with line break\n\n"
            "[2] Second entry"
        )
        doc = parse_document(text)
        self.assertEqual(len(doc.references), 2)

    def test_unicode_in_titles(self):
        text = "Body [1] cites Müller study."
        sources = [make_source("Müller study on β-blockers", 2020)]
        result = apply_finalization(text, sources)
        self.assertIn("Müller", result)
        self.assertIn("β-blockers", result)


class CitationOrderTests(unittest.TestCase):
    """Targeted tests for the 'citations must be in order 1,2,3,4' invariant."""

    def test_ascending_citations_unchanged(self):
        text = "First [1]. Second [2]. Third [3]. Fourth [4]."
        sources = [make_source(f"S{c}", 2020) for c in "ABCD"]
        result = apply_finalization(text, sources)
        self.assertEqual(_body_numbers(result), [1, 2, 3, 4])
        self.assertEqual(_ref_numbers(result), [1, 2, 3, 4])

    def test_descending_citations_renumbered(self):
        text = "First [4]. Second [3]. Third [2]. Fourth [1]."
        sources = [make_source(f"S{c}", 2020) for c in "ABCD"]
        result = apply_finalization(text, sources)
        # First-cited (originally 4) becomes 1; etc.
        self.assertEqual(_body_numbers(result), [1, 2, 3, 4])
        # Reference [1] corresponds to original source 4 (S D)
        first_ref = result.split("### References")[1].split("[2]")[0]
        self.assertIn("SD", first_ref)

    def test_random_citations_renumbered(self):
        text = "First [2]. Second [4]. Third [1]. Fourth [3]."
        sources = [make_source(f"S{c}", 2020) for c in "ABCD"]
        result = apply_finalization(text, sources)
        self.assertEqual(_body_numbers(result), [1, 2, 3, 4])
        # First cited (orig 2) → [1] → S B; orig 4 → [2] → S D;
        # orig 1 → [3] → S A; orig 3 → [4] → S C
        refs = result.split("### References")[1]
        # Get individual ref entries
        ref_entries = re.findall(r"\[(\d+)\][^\[]*?(SA|SB|SC|SD)", refs)
        renumber_map = dict(ref_entries)
        self.assertEqual(renumber_map.get("1"), "SB")
        self.assertEqual(renumber_map.get("2"), "SD")
        self.assertEqual(renumber_map.get("3"), "SA")
        self.assertEqual(renumber_map.get("4"), "SC")

    def test_subset_of_sources_used(self):
        text = "Only one cited [3]."
        sources = [make_source(f"S{c}", 2020) for c in "ABCDEF"]
        result = apply_finalization(text, sources)
        # Only one citation: [1] mapping to original source 3 (S C)
        self.assertEqual(_body_numbers(result), [1])
        self.assertEqual(_ref_numbers(result), [1])
        self.assertIn("SC", result.split("### References")[1])

    def test_repeat_citations_preserve_first_appearance(self):
        text = "First [3]. Then [3] again. Now [1]. And [3] yet again. Finally [2]."
        sources = [make_source(f"S{c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        # Order of unique first appearances: 3, 1, 2 → renumbered to 1, 2, 3
        first_unique = _unique_in_order(_body_numbers(result))
        self.assertEqual(first_unique, [1, 2, 3])

    def test_same_citation_in_grouped_block(self):
        text = "Body [1, 1, 1] then [1]."
        sources = [make_source("A", 2020)]
        result = apply_finalization(text, sources)
        # Should dedupe within block but allow repeats across blocks
        body = result.split("### References")[0]
        # First block becomes [1] (single), second block also [1]
        self.assertEqual(body.count("[1]"), 2)


class CanonicalSourceFilterTests(unittest.TestCase):
    """Tests for the canonical_sources ≤ all_sources invariant.

    When the polish LLM cites only a subset of retrieved sources,
    `finalize_with_canonical` returns only the cited ones. The report,
    committed output, and judge panel must all use the same filtered set
    so N references is consistent everywhere.
    """

    def test_canonical_drops_uncited_sources(self):
        # 8 sources retrieved, polish only cites 7 of them
        text = (
            "Body uses [1], [2], [3], [4], [5], [6], and [7].\n\n"
            "### References\n\n"
            "[1] A\n[2] B\n[3] C\n[4] D\n[5] E\n[6] F\n[7] G"
        )
        sources = [make_source(c, 2020) for c in "ABCDEFGH"]  # 8 sources
        _canon_text, canon_sources = finalize_with_canonical(text, sources)
        # Only 7 are cited, so canonical should have 7 (not 8)
        self.assertEqual(len(canon_sources), 7)
        # And they preserve first-appearance order
        self.assertEqual(
            [s["title"] for s in canon_sources],
            ["A", "B", "C", "D", "E", "F", "G"],
        )

    def test_canonical_preserves_all_when_all_cited(self):
        text = "Uses [1], [2], [3].\n\n### References\n\n[1] A\n[2] B\n[3] C"
        sources = [make_source(c, 2020) for c in "ABC"]
        _text, canon = finalize_with_canonical(text, sources)
        self.assertEqual(len(canon), 3)

    def test_canonical_with_out_of_order_partial_citation(self):
        # 5 sources retrieved, polish cites only [3] and [5]
        text = "Claim [3] and finding [5].\n\n### References\n\n[3] C\n[5] E"
        sources = [make_source(c, 2020) for c in "ABCDE"]
        _text, canon = finalize_with_canonical(text, sources)
        # Only 2 cited → canonical has 2
        self.assertEqual(len(canon), 2)
        # First cited was [3]=C → canonical[0]=C
        self.assertEqual(canon[0]["title"], "C")
        self.assertEqual(canon[1]["title"], "E")

    def test_body_references_match_canonical_count(self):
        """Body citation count should equal canonical_sources length after finalize."""
        text = "Body [1][2][3][4].\n\n### References\n\n[1] A\n[2] B\n[3] C\n[4] D"
        sources = [make_source(c, 2020) for c in "ABCDEF"]  # 6 retrieved
        canon_text, canon_sources = finalize_with_canonical(text, sources)
        body_nums = sorted(set(_body_numbers(canon_text)))
        self.assertEqual(len(body_nums), len(canon_sources))
        self.assertEqual(body_nums, list(range(1, len(canon_sources) + 1)))


class BodyStrippingChainTests(unittest.TestCase):
    """Test the finalize_with_canonical → parse_document → .body chain used by
    main.py to strip the References section before SSE."""

    def test_body_only_has_no_references_section(self):
        text = "Body [1] then [2].\n\n### References\n\n[1] Smith (2020).\n\n[2] Jones (2021)."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        canon_text, _ = finalize_with_canonical(text, sources)
        body = parse_document(canon_text).body.strip()
        self.assertNotIn("References", body)
        self.assertNotIn("[1] Smith", body)
        self.assertIn("[1]", body)  # inline citations preserved
        self.assertIn("[2]", body)

    def test_body_preserves_inline_citations(self):
        text = "Claim [2] supported by evidence [1].\n\n### References\n\n[1] Ref.\n\n[2] Ref."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        canon_text, _ = finalize_with_canonical(text, sources)
        body = parse_document(canon_text).body.strip()
        self.assertIn("[1]", body)
        self.assertIn("[2]", body)

    def test_canonical_sources_match_body_order(self):
        text = "First [3] then [1].\n\n### References\n\n[1] Ref1.\n\n[3] Ref3."
        sources = [make_source("A", 2020), make_source("B", 2021), make_source("C", 2022)]
        _, canon_sources = finalize_with_canonical(text, sources)
        # [3]=C appeared first in body → canon_sources[0] = C
        self.assertEqual(canon_sources[0]["title"], "C")
        self.assertEqual(canon_sources[1]["title"], "A")

    def test_no_sources_returns_body_without_citations(self):
        text = "Body [1] and [2]."
        canon_text, canon_sources = finalize_with_canonical(text, [])
        body = parse_document(canon_text).body.strip()
        self.assertNotIn("[1]", body)
        self.assertEqual(canon_sources, [])

    def test_round_trip_idempotent(self):
        text = "Body [1][2].\n\n### References\n\n[1] Smith. (2020).\n\n[2] Jones. (2021)."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        t1, s1 = finalize_with_canonical(text, sources)
        b1 = parse_document(t1).body.strip()
        # Second pass on the body-only + canonical sources should be stable
        ref_block = "\n\n".join(f"[{i+1}] {s['title']}" for i, s in enumerate(s1))
        second_input = b1 + "\n\n### References\n\n" + ref_block
        t2, _ = finalize_with_canonical(second_input, s1)
        b2 = parse_document(t2).body.strip()
        self.assertEqual(b1, b2)


class AdjacentCitationSortTests(unittest.TestCase):
    """Adjacent [N][M] blocks should be sorted ascending (academic convention)."""

    def test_reversed_pair_sorted(self):
        text = "Some claim [2][1]."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        self.assertIn("[1][2]", body)

    def test_triple_reversed_sorted(self):
        text = "Evidence [3][1][2]."
        sources = [make_source(f"S{c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        self.assertIn("[1][2][3]", body)

    def test_already_sorted_unchanged(self):
        text = "Evidence [1][2][3]."
        sources = [make_source(f"S{c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        self.assertIn("[1][2][3]", body)

    def test_duplicates_in_adjacent_deduped_and_sorted(self):
        text = "Claim [3][1][3][2]."
        sources = [make_source(f"S{c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        self.assertIn("[1][2][3]", body)

    def test_separate_groups_sorted_independently(self):
        text = "First [2][1] and second [3][1]."
        sources = [make_source(f"S{c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        # After renumber: orig 2→1, orig 1→2, orig 3→3
        # First group: [1][2] (sorted). Second: [3][2] → [2][3] (sorted).
        self.assertIn("[1][2]", body)
        self.assertIn("[2][3]", body)

    def test_single_citation_unchanged(self):
        text = "Claim [1]."
        sources = [make_source("A", 2020)]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        self.assertIn("[1]", body)

    def test_non_adjacent_not_merged(self):
        """[2] some text [1] — not adjacent, so no sorting."""
        text = "First [2] then some words [1]."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        # After renumbering: [2] appeared first → [1], [1] appeared second → [2]
        nums = _body_numbers(result)
        self.assertEqual(nums, [1, 2])


if __name__ == "__main__":
    unittest.main()