crcs-live / tests /test_citations_structured.py
Nipun's picture
v0.27.2: Fix report/committed output reference count mismatch, CRCS full form
e0edcb1
"""Tests for the structured citations module.
These tests exercise the parse β†’ manipulate β†’ render flow without going through
any LLM call. They cover sequential numbering, phantom citations, escaped
brackets, grouped/range syntax, references parsing, and edge cases like
citations next to punctuation.
"""
import re
import unittest
from citations import (
CitationDocument,
Reference,
apply_finalization,
body_and_references_match,
finalize_with_canonical,
has_escaped_brackets,
has_grouped_or_range_syntax,
has_phantom_citations,
is_sequential_in_first_appearance_order,
parse_document,
)
from tests.helpers import make_source
def _body_numbers(text: str) -> list[int]:
    """Return every ``[N]`` number that appears before ``### References``.

    Numbers come back in document order with duplicates kept. Text without a
    ``### References`` heading is treated as body in its entirety.
    """
    body_part, _heading, _rest = text.partition("### References")
    return list(map(int, re.findall(r"\[(\d+)\]", body_part)))
def _ref_numbers(text: str) -> list[int]:
    """Return every ``[N]`` number found after the first ``### References``.

    Returns an empty list when the text has no ``### References`` heading.
    """
    _body_part, heading, refs_part = text.partition("### References")
    if not heading:
        return []
    return list(map(int, re.findall(r"\[(\d+)\]", refs_part)))
def _unique_in_order(nums):
    """Return *nums* as a list with repeats removed, keeping first occurrences."""
    # dicts preserve insertion order, so the keys are the first-seen values.
    return list(dict.fromkeys(nums))
class ParseDocumentTests(unittest.TestCase):
    """Splitting raw text into a body and references via ``parse_document``.

    Covers empty/None input, body-only text, several spellings of the
    References heading, escaped brackets, and multi-entry reference sections.
    """

    def test_empty_input_returns_empty_document(self):
        parsed = parse_document("")
        self.assertEqual(parsed.body, "")
        self.assertEqual(parsed.references, [])

    def test_none_input_returns_empty_document(self):
        parsed = parse_document(None)  # type: ignore[arg-type]
        self.assertEqual(parsed.body, "")
        self.assertEqual(parsed.references, [])

    def test_body_only_no_references(self):
        raw = "Plain body with [1] citation."
        parsed = parse_document(raw)
        self.assertEqual(parsed.body, raw)
        self.assertEqual(parsed.references, [])

    def test_body_and_references_section_split(self):
        parsed = parse_document("Body [1].\n\n### References\n\n[1] Source A")
        self.assertEqual(parsed.body, "Body [1].")
        self.assertEqual(len(parsed.references), 1)
        only_entry = parsed.references[0]
        self.assertEqual(only_entry.number, 1)
        self.assertEqual(only_entry.text, "Source A")

    def test_plain_references_heading_no_markdown(self):
        parsed = parse_document("Body [1].\n\nReferences\n\n[1] Source A")
        self.assertIn("Body [1]", parsed.body)
        self.assertEqual(len(parsed.references), 1)

    def test_references_heading_case_insensitive(self):
        parsed = parse_document("Body [1].\n\n## REFERENCES\n\n[1] Source A")
        self.assertEqual(len(parsed.references), 1)

    def test_references_with_trailing_colon(self):
        parsed = parse_document("Body [1].\n\n### References:\n\n[1] Source A")
        self.assertEqual(len(parsed.references), 1)

    def test_escaped_brackets_are_unescaped(self):
        raw = "Body \\[1\\]\\[2\\].\n\n### References\n\n\\[1\\] Source A"
        parsed = parse_document(raw)
        self.assertEqual(parsed.body, "Body [1][2].")
        self.assertEqual(parsed.references[0].number, 1)

    def test_multiple_reference_entries_parsed(self):
        raw = "\n\n".join([
            "Body [1] [2] [3].",
            "### References",
            "[1] First source",
            "[2] Second source",
            "[3] Third source",
        ])
        parsed = parse_document(raw)
        self.assertEqual(len(parsed.references), 3)
        self.assertEqual(parsed.references[-1].text, "Third source")

    def test_reference_lines_not_double_spaced(self):
        raw = "Body [1].\n\n### References\n[1] First source\n[2] Stale"
        parsed = parse_document(raw)
        # Entries separated only by single newlines must still yield at least one.
        self.assertGreaterEqual(len(parsed.references), 1)
class CitedNumbersTests(unittest.TestCase):
    """``CitationDocument.cited_numbers()`` and ``first_appearance_order()``."""

    @staticmethod
    def _cited(text):
        """Parse *text* and return its cited numbers in document order."""
        return parse_document(text).cited_numbers()

    def test_no_citations_returns_empty(self):
        parsed = parse_document("Plain body, no markers.")
        self.assertEqual(parsed.cited_numbers(), [])
        self.assertEqual(parsed.first_appearance_order(), [])

    def test_single_citation(self):
        self.assertEqual(self._cited("Body with [3] citation."), [3])

    def test_grouped_comma_citation(self):
        self.assertEqual(self._cited("Body with [1, 2, 3] citation."), [1, 2, 3])

    def test_grouped_semicolon_citation(self):
        self.assertEqual(self._cited("Body with [1; 2] citation."), [1, 2])

    def test_range_citation(self):
        self.assertEqual(self._cited("Body with [2-5] citation."), [2, 3, 4, 5])

    def test_adjacent_blocks(self):
        self.assertEqual(self._cited("Body with [1][2][3] citations."), [1, 2, 3])

    def test_first_appearance_order_dedupes(self):
        parsed = parse_document("First [3]. Second [1]. Third [3]. Fourth [2].")
        self.assertEqual(parsed.first_appearance_order(), [3, 1, 2])

    def test_citations_with_punctuation_around(self):
        self.assertEqual(
            self._cited("Sentence ends here[1]. Another sentence,[2] continues."),
            [1, 2],
        )
class NormalizeTests(unittest.TestCase):
    """``CitationDocument.normalize(source_count)``.

    Normalization drops citations outside ``1..source_count`` and renumbers
    the remaining ones sequentially in first-appearance order.
    """

    def test_normalize_drops_phantom_citations(self):
        doc = parse_document("Body [1][2][5][6].")
        result = doc.normalize(source_count=2)
        self.assertEqual(result.cited_numbers(), [1, 2])

    def test_normalize_renumbers_to_sequential(self):
        doc = parse_document("First [3]. Second [1]. Third [2].")
        result = doc.normalize(source_count=3)
        # First-cited gets [1], next new gets [2], etc.
        self.assertEqual(result.first_appearance_order(), [1, 2, 3])
        self.assertIn("First [1]", result.body)
        self.assertIn("Second [2]", result.body)
        self.assertIn("Third [3]", result.body)

    def test_normalize_handles_groups(self):
        doc = parse_document("Body [3, 1, 5].")
        result = doc.normalize(source_count=3)
        # 5 is dropped; 3 and 1 become 1 and 2 in first-appearance order.
        self.assertIn("[1][2]", result.body)
        self.assertNotIn("[5]", result.body)

    def test_normalize_handles_range(self):
        doc = parse_document("Body [1-3].")
        result = doc.normalize(source_count=3)
        self.assertEqual(result.cited_numbers(), [1, 2, 3])

    def test_normalize_zero_sources_strips_all(self):
        doc = parse_document("Body [1][2][3].")
        result = doc.normalize(source_count=0)
        self.assertEqual(result.cited_numbers(), [])
        self.assertNotIn("[", result.body)

    def test_normalize_only_phantom_strips_all(self):
        doc = parse_document("Body [99][100].")
        result = doc.normalize(source_count=3)
        self.assertEqual(result.cited_numbers(), [])

    def test_normalize_negative_source_count(self):
        doc = parse_document("Body [1][2].")
        result = doc.normalize(source_count=-5)
        self.assertEqual(result.cited_numbers(), [])

    def test_normalize_cleans_punctuation_after_removal(self):
        doc = parse_document("First sentence [99]. Second sentence.")
        result = doc.normalize(source_count=3)
        # Removing the phantom marker must not leave a double space behind,
        # and the sentence-ending period must survive. (Checking for a single
        # space would always fail: the body legitimately contains spaces.)
        self.assertNotIn("[99]", result.body)
        self.assertNotIn("  ", result.body)
        self.assertIn(".", result.body)

    def test_normalize_dedupes_within_block(self):
        doc = parse_document("Body [1, 1, 1].")
        result = doc.normalize(source_count=3)
        self.assertEqual(result.cited_numbers(), [1])
class ApplyFinalizationTests(unittest.TestCase):
    """End-to-end ``apply_finalization(text, sources)`` -> final answer string.

    Each test feeds LLM-style text plus retrieved sources and inspects the
    rendered body and the regenerated References section.
    """

    @staticmethod
    def _body_of(rendered):
        """Return the part of *rendered* that precedes ``### References``."""
        return rendered.split("### References")[0]

    @staticmethod
    def _two_sources():
        """Sources titled A (2020) and B (2021)."""
        return [make_source("A", 2020), make_source("B", 2021)]

    @staticmethod
    def _three_sources():
        """Sources titled A, B, C dated 2020, 2021, 2022."""
        return [make_source(letter, 2020 + offset) for offset, letter in enumerate("ABC")]

    def test_basic_sequential_output(self):
        raw = (
            "First fact [1].\n\n"
            "Second fact [2].\n\n"
            "### References\n\n"
            "[1] stale\n\n"
            "[2] stale"
        )
        rendered = apply_finalization(
            raw, [make_source("Source A", 2021), make_source("Source B", 2022)]
        )
        self.assertEqual(_body_numbers(rendered), [1, 2])
        self.assertEqual(_ref_numbers(rendered), [1, 2])
        for title in ("Source A", "Source B"):
            self.assertIn(title, rendered)
        # LLM-written reference entries are replaced, never carried over.
        self.assertNotIn("stale", rendered)

    def test_phantom_citations_dropped(self):
        raw = (
            "Time crystals are exotic [1][2].\n\n"
            "Wilczek proposed them [3].\n\n"
            "Lab observations [4][6].\n\n"  # [6] has no matching source
            "MBL stabilises them [5][6].\n\n"  # neither [5] nor [6] has a source
            "### References\n\n[1] stale\n\n[2] stale\n\n[3] stale\n\n[4] stale"
        )
        sources = [
            make_source("Wilczek 2012", 2012),
            make_source("Khemani 2016", 2016),
            make_source("Else 2016", 2016),
            make_source("Zhang 2017", 2017),
        ]
        rendered = apply_finalization(raw, sources)
        cited = sorted(set(_body_numbers(rendered)))
        listed = sorted(set(_ref_numbers(rendered)))
        self.assertEqual(cited, listed)
        self.assertLessEqual(max(cited, default=0), len(sources))
        body = self._body_of(rendered)
        self.assertNotIn("[5]", body)
        self.assertNotIn("[6]", body)

    def test_escaped_brackets_round_trip(self):
        raw = (
            "Body \\[1\\]\\[2\\].\n\n"
            "Second paragraph cites \\[3\\] more.\n\n"
            "### References\n\n"
            "\\[1\\] stale\n\n\\[2\\] stale\n\n\\[3\\] stale"
        )
        sources = [
            make_source(f"Source {letter}", 2020 + offset)
            for offset, letter in enumerate("ABC")
        ]
        rendered = apply_finalization(raw, sources)
        for escaped in ("\\[", "\\]"):
            self.assertNotIn(escaped, rendered)
        self.assertEqual(_body_numbers(rendered), [1, 2, 3])
        self.assertEqual(_ref_numbers(rendered), [1, 2, 3])

    def test_first_appearance_order_is_canonical(self):
        raw = (
            "Citing source 3 first [3].\n\n"
            "Then source 1 [1].\n\n"
            "Then source 2 [2]."
        )
        sources = [
            make_source("First in list", 2021),
            make_source("Second in list", 2022),
            make_source("Third in list", 2023),
        ]
        rendered = apply_finalization(raw, sources)
        body = self._body_of(rendered)
        self.assertIn("Citing source 3 first [1]", body)
        self.assertIn("Then source 1 [2]", body)
        self.assertIn("Then source 2 [3]", body)
        # The new [1] must resolve to the source originally cited as [3].
        references = rendered.split("### References")[1]
        self.assertIn("[1]", references)
        self.assertIn("Third in list", references.split("[2]")[0])

    def test_grouped_citation_normalised_to_adjacent_blocks(self):
        body = self._body_of(
            apply_finalization("Body has grouped citation [1, 2, 3].", self._three_sources())
        )
        self.assertIn("[1][2][3]", body)
        self.assertNotIn("[1, 2, 3]", body)

    def test_range_citation_expanded(self):
        body = self._body_of(
            apply_finalization("Body has range citation [1-3].", self._three_sources())
        )
        self.assertIn("[1][2][3]", body)
        self.assertNotIn("[1-3]", body)

    def test_no_sources_strips_all_citations(self):
        rendered = apply_finalization(
            "Body [1] with [2] citations.\n\n### References\n\n[1] stale", sources=[]
        )
        for marker in ("[1]", "[2]"):
            self.assertNotIn(marker, rendered)

    def test_no_citations_in_text_returns_body(self):
        plain = "Plain body with no citations at all."
        rendered = apply_finalization(plain, sources=[make_source("A", 2020)])
        self.assertEqual(rendered.strip(), plain)

    def test_only_phantom_citations_strips_all(self):
        rendered = apply_finalization(
            "Body [99] with [100] citations.", sources=[make_source("A", 2020)]
        )
        for marker in ("[99]", "[100]"):
            self.assertNotIn(marker, rendered)

    def test_duplicate_citations_in_same_paragraph(self):
        rendered = apply_finalization("First [1] then [1] again [1].", [make_source("A", 2020)])
        self.assertEqual(self._body_of(rendered).count("[1]"), 3)
        self.assertEqual(_ref_numbers(rendered), [1])

    def test_multiple_substantive_paragraphs_get_backfilled(self):
        raw = "\n\n".join([
            "Opening paragraph cites the source explicitly [1].",
            "This middle paragraph has enough words to count as substantive but lacks a citation entirely.",
            "Closing paragraph also has support [2].",
        ])
        body = self._body_of(apply_finalization(raw, self._two_sources()))
        # The uncited middle paragraph should borrow a neighbour's citation.
        self.assertRegex(
            body,
            r"This middle paragraph has enough words to count as substantive but lacks a citation entirely\.\s*\[\d+\]",
        )

    def test_short_paragraph_not_backfilled(self):
        raw = "\n\n".join(["Cited [1].", "Tiny.", "Another cited paragraph [2]."])
        rendered = apply_finalization(raw, self._two_sources())
        # "Tiny." is too short to count as substantive, so nothing is appended.
        self.assertNotRegex(rendered, r"Tiny\.\s*\[\d+\]")

    def test_headings_are_not_backfilled(self):
        raw = "\n\n".join([
            "Body paragraph cites a source [1].",
            "## A heading should not get a citation",
            "Closing paragraph has another [2].",
        ])
        rendered = apply_finalization(raw, self._two_sources())
        self.assertNotRegex(rendered, r"## A heading should not get a citation\s*\[\d+\]")

    def test_list_items_are_not_backfilled(self):
        raw = "\n\n".join([
            "Cited paragraph here [1].",
            "- list item should not be backfilled even if long enough",
            "Another cited paragraph [2].",
        ])
        rendered = apply_finalization(raw, self._two_sources())
        self.assertNotRegex(rendered, r"list item should not be backfilled[^\n]*\[\d+\]")

    def test_dedupes_within_grouped_block(self):
        rendered = apply_finalization("Body [1, 1, 2, 2].", self._two_sources())
        self.assertIn("[1][2]", self._body_of(rendered))

    def test_phantom_in_groups_filtered(self):
        body = self._body_of(
            apply_finalization("Body has [2, 99, 1, 100, 3].", self._three_sources())
        )
        for phantom in ("[99]", "[100]"):
            self.assertNotIn(phantom, body)
        # The valid members keep their relative order (2, 1, 3) -> [1][2][3].
        self.assertIn("[1][2][3]", body)
class DiagnosticsTests(unittest.TestCase):
    """Boolean diagnostics: phantom, grouped/range, escaped, sequential order,
    and body/reference agreement."""

    def test_has_phantom_detects_out_of_range(self):
        self.assertTrue(has_phantom_citations("Body [3] [99]", source_count=5))
        self.assertFalse(has_phantom_citations("Body [3] [5]", source_count=5))

    def test_has_phantom_detects_zero_or_negative(self):
        # [0] lies below the valid citation range, so it counts as phantom.
        self.assertTrue(has_phantom_citations("Body [0]", source_count=5))

    def test_has_grouped_detects_comma_syntax(self):
        self.assertTrue(has_grouped_or_range_syntax("Body [1, 2]"))
        self.assertFalse(has_grouped_or_range_syntax("Body [1][2]"))

    def test_has_grouped_detects_range_syntax(self):
        self.assertTrue(has_grouped_or_range_syntax("Body [1-3]"))

    def test_has_grouped_detects_semicolon_syntax(self):
        self.assertTrue(has_grouped_or_range_syntax("Body [1; 2]"))

    def test_has_escaped_brackets(self):
        self.assertTrue(has_escaped_brackets("Body \\[1\\]"))
        self.assertFalse(has_escaped_brackets("Body [1]"))

    def test_sequential_in_first_appearance_order(self):
        is_sequential = is_sequential_in_first_appearance_order
        self.assertTrue(is_sequential("First [1] then [2] then [3]"))
        self.assertFalse(is_sequential("First [2] then [1] then [3]"))
        # Repeating an already-seen number does not break the sequence.
        self.assertTrue(is_sequential("Body [1] [1] [2] [2]"))

    def test_body_and_references_match(self):
        self.assertTrue(body_and_references_match(
            "Body [1] [2].\n\n### References\n\n[1] A\n\n[2] B"
        ))

    def test_body_and_references_mismatch(self):
        # The body cites [2] but the References section has no entry for it.
        self.assertFalse(body_and_references_match(
            "Body [1] [2].\n\n### References\n\n[1] A"
        ))
class FirstAppearanceOrderWithinTests(unittest.TestCase):
    """``first_appearance_order_within(N)``: first-appearance order limited to
    citation numbers inside the valid range."""

    def test_filters_out_of_range(self):
        order = parse_document("Body [3] [99] [1] [2]").first_appearance_order_within(3)
        self.assertEqual(order, [3, 1, 2])

    def test_dedupes(self):
        order = parse_document("Body [1] [1] [2] [2] [1]").first_appearance_order_within(5)
        self.assertEqual(order, [1, 2])

    def test_empty(self):
        order = parse_document("Body with nothing").first_appearance_order_within(5)
        self.assertEqual(order, [])
class RealWorldLLMOutputTests(unittest.TestCase):
    """Stress tests with messy real-world LLM output patterns."""

    def test_mixed_escaped_and_unescaped(self):
        text = (
            "First paragraph cites both \\[1\\] and [2].\n\n"
            "Second paragraph mixes [3]\\[4\\]."
        )
        sources = [make_source(f"Source {c}", 2020) for c in "ABCD"]
        result = apply_finalization(text, sources)
        self.assertNotIn("\\[", result)
        self.assertEqual(_body_numbers(result), [1, 2, 3, 4])

    def test_citations_in_bulleted_lists(self):
        text = (
            "Treatment options include:\n"
            "- First-line therapy [1]\n"
            "- Second-line therapy [2]\n"
            "- Specialist option [3]\n\n"
            "Closing paragraph that summarises [1][2]."
        )
        sources = [make_source(f"Source {c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        self.assertEqual(set(_body_numbers(result)), {1, 2, 3})

    def test_citations_in_numbered_lists(self):
        text = (
            "1. First step [1]\n"
            "2. Second step [2]\n"
            "3. Third step [3]\n"
        )
        sources = [make_source(f"Source {c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        self.assertEqual(set(_body_numbers(result)), {1, 2, 3})

    def test_citations_separated_by_pre_period_space(self):
        # Some LLMs put a space before the period: "claim [1] ."
        text = "Some claim [1] . Another [2] !"
        sources = [make_source("A", 2020), make_source("B", 2021)]
        result = apply_finalization(text, sources)
        self.assertEqual(_body_numbers(result), [1, 2])

    def test_citations_at_end_of_paragraph(self):
        text = "First paragraph here.\n\nSecond paragraph ends with citation [1]."
        sources = [make_source("A", 2020)]
        result = apply_finalization(text, sources)
        # Only checks that the trailing citation survives finalization; whether
        # the short first paragraph is also backfilled is not asserted here.
        body = result.split("### References")[0]
        self.assertIn("[1]", body)

    def test_consecutive_citations_with_no_space(self):
        text = "Body[1][2][3]ends suddenly."
        sources = [make_source(f"Source {c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        self.assertEqual(_body_numbers(result), [1, 2, 3])

    def test_citation_inside_inline_code(self):
        # `[1]` inside backticks is currently treated as a citation;
        # this test documents the behaviour, not necessarily the desired one.
        text = "Use the function `parse([1])` and cite [1]."
        sources = [make_source("A", 2020)]
        result = apply_finalization(text, sources)
        # Both [1] tokens should resolve to source 1.
        body = result.split("### References")[0]
        self.assertIn("[1]", body)

    def test_long_paragraph_with_multiple_citations(self):
        text = (
            "Time crystals are a phase of matter that breaks discrete time "
            "translation symmetry [1]. Wilczek originally proposed continuous "
            "time crystals [2], though no-go theorems ruled them out [3][4]. "
            "Discrete versions emerged from periodically driven systems [5]. "
            "Many-body localisation stabilises them [6]."
        )
        sources = [make_source(f"Source {c}", 2017 + i) for i, c in enumerate("ABCDEF")]
        result = apply_finalization(text, sources)
        self.assertEqual(set(_body_numbers(result)), {1, 2, 3, 4, 5, 6})
        self.assertEqual(set(_ref_numbers(result)), {1, 2, 3, 4, 5, 6})

    def test_completely_phantom_text_falls_back_gracefully(self):
        # The body has only out-of-range citations and the LLM also wrote stale refs.
        text = (
            "Body cites phantom sources [99] and [100].\n\n"
            "### References\n\n[99] Bogus\n\n[100] Bogus"
        )
        sources = [make_source("Real source", 2020)]
        result = apply_finalization(text, sources)
        self.assertNotIn("[99]", result)
        self.assertNotIn("[100]", result)
        self.assertNotIn("Bogus", result)

    def test_semicolon_grouped_citation(self):
        text = "Body [1; 3; 2]."
        sources = [make_source(f"S{c}", 2020) for c in "ABC"]
        result = apply_finalization(text, sources)
        body = result.split("### References")[0]
        # Order in source: 1, 3, 2 -> first appearance is [1] (=A), [2] (=C), [3] (=B).
        self.assertIn("[1][2][3]", body)
        # Reference [1] = original source 1 (SA). Inspect the first entry only:
        # the whole section contains SA whatever the numbering.
        first_entry = result.split("### References")[1].split("[2]")[0]
        self.assertIn("SA", first_entry)

    def test_finalized_output_is_self_consistent(self):
        """After finalization, body and references should reference the same set
        of sources, in the same order, with first body citation = [1]."""
        text = "Body [3] then [1] then [2]."
        sources = [make_source("A", 2020), make_source("B", 2021), make_source("C", 2022)]
        result = apply_finalization(text, sources)
        body_nums = _body_numbers(result)
        ref_nums = _ref_numbers(result)
        # Sets must match.
        self.assertEqual(set(body_nums), set(ref_nums))
        # Body must start at [1] in first-appearance order.
        self.assertEqual(_unique_in_order(body_nums), list(range(1, len(set(body_nums)) + 1)))
        # References must be 1..N sequential.
        self.assertEqual(ref_nums, sorted(set(ref_nums)))

    def test_finalize_with_canonical_is_truly_idempotent(self):
        """finalize_with_canonical returns (text, sources) that round-trip cleanly."""
        text = "Body [3] then [1] then [2]."
        sources = [make_source("A", 2020), make_source("B", 2021), make_source("C", 2022)]
        once_text, once_sources = finalize_with_canonical(text, sources)
        twice_text, twice_sources = finalize_with_canonical(once_text, once_sources)
        self.assertEqual(once_text, twice_text)
        self.assertEqual(once_sources, twice_sources)
        # The canonical sources follow the body's first-appearance order:
        # original [3]=C is cited first, so canonical[0] = C.
        self.assertEqual(once_sources[0]["title"], "C")
        self.assertEqual(once_sources[1]["title"], "A")
        self.assertEqual(once_sources[2]["title"], "B")

    def test_idempotent_with_phantom_input(self):
        text = "Body [1][99][2][100]"
        sources = [make_source("A", 2020), make_source("B", 2021)]
        once = apply_finalization(text, sources)
        twice = apply_finalization(once, sources)
        self.assertEqual(once, twice)

    def test_blank_lines_in_references_preserved_for_parsing(self):
        text = (
            "Body [1] [2].\n\n"
            "### References\n\n"
            "[1] First entry\n"
            "with line break\n\n"
            "[2] Second entry"
        )
        doc = parse_document(text)
        self.assertEqual(len(doc.references), 2)

    def test_unicode_in_titles(self):
        # Non-ASCII titles (Latin diacritics and Greek letters) must survive
        # rendering unchanged.
        text = "Body [1] cites Müller study."
        sources = [make_source("Müller study on β-blockers", 2020)]
        result = apply_finalization(text, sources)
        self.assertIn("Müller", result)
        self.assertIn("β-blockers", result)
class CitationOrderTests(unittest.TestCase):
    """Targeted tests for the invariant that body citations read 1, 2, 3, 4...
    in first-appearance order, with references renumbered to match."""

    @staticmethod
    def _letter_sources(letters):
        """Build one 2020 source per letter, titled ``S<letter>``."""
        return [make_source(f"S{letter}", 2020) for letter in letters]

    def test_ascending_citations_unchanged(self):
        rendered = apply_finalization(
            "First [1]. Second [2]. Third [3]. Fourth [4].", self._letter_sources("ABCD")
        )
        self.assertEqual(_body_numbers(rendered), [1, 2, 3, 4])
        self.assertEqual(_ref_numbers(rendered), [1, 2, 3, 4])

    def test_descending_citations_renumbered(self):
        rendered = apply_finalization(
            "First [4]. Second [3]. Third [2]. Fourth [1].", self._letter_sources("ABCD")
        )
        # The first-cited source (originally [4]) becomes [1], and so on.
        self.assertEqual(_body_numbers(rendered), [1, 2, 3, 4])
        # Reference entry [1] must therefore describe original source 4 ("SD").
        leading_entry = rendered.split("### References")[1].split("[2]")[0]
        self.assertIn("SD", leading_entry)

    def test_random_citations_renumbered(self):
        rendered = apply_finalization(
            "First [2]. Second [4]. Third [1]. Fourth [3].", self._letter_sources("ABCD")
        )
        self.assertEqual(_body_numbers(rendered), [1, 2, 3, 4])
        # Expected mapping: orig 2 -> [1] (SB), orig 4 -> [2] (SD),
        # orig 1 -> [3] (SA), orig 3 -> [4] (SC).
        references = rendered.split("### References")[1]
        number_to_title = dict(
            re.findall(r"\[(\d+)\][^\[]*?(SA|SB|SC|SD)", references)
        )
        expected = {"1": "SB", "2": "SD", "3": "SA", "4": "SC"}
        for number, title in expected.items():
            self.assertEqual(number_to_title.get(number), title)

    def test_subset_of_sources_used(self):
        rendered = apply_finalization("Only one cited [3].", self._letter_sources("ABCDEF"))
        # The lone citation becomes [1] and points at original source 3 ("SC").
        self.assertEqual(_body_numbers(rendered), [1])
        self.assertEqual(_ref_numbers(rendered), [1])
        self.assertIn("SC", rendered.split("### References")[1])

    def test_repeat_citations_preserve_first_appearance(self):
        rendered = apply_finalization(
            "First [3]. Then [3] again. Now [1]. And [3] yet again. Finally [2].",
            self._letter_sources("ABC"),
        )
        # Unique first appearances are 3, 1, 2, which renumber to 1, 2, 3.
        self.assertEqual(_unique_in_order(_body_numbers(rendered)), [1, 2, 3])

    def test_same_citation_in_grouped_block(self):
        rendered = apply_finalization("Body [1, 1, 1] then [1].", [make_source("A", 2020)])
        # Repeats collapse inside one block but survive across blocks: the
        # group becomes a single [1] and the trailing [1] remains.
        self.assertEqual(rendered.split("### References")[0].count("[1]"), 2)
class CanonicalSourceFilterTests(unittest.TestCase):
    """Invariant: canonical_sources ⊆ all_sources.

    When the polish LLM cites only some of the retrieved sources,
    ``finalize_with_canonical`` returns only the cited ones. The report,
    committed output, and judge panel must all use that same filtered list,
    so the reference count N is consistent everywhere.
    """

    def test_canonical_drops_uncited_sources(self):
        # Eight sources are retrieved, but the polished text cites only seven.
        text = (
            "Body uses [1], [2], [3], [4], [5], [6], and [7].\n\n"
            "### References\n\n"
            "[1] A\n[2] B\n[3] C\n[4] D\n[5] E\n[6] F\n[7] G"
        )
        retrieved = [make_source(letter, 2020) for letter in "ABCDEFGH"]
        _, canonical = finalize_with_canonical(text, retrieved)
        # The uncited eighth source is dropped ...
        self.assertEqual(len(canonical), 7)
        # ... and the cited ones keep their first-appearance order.
        self.assertEqual([s["title"] for s in canonical], list("ABCDEFG"))

    def test_canonical_preserves_all_when_all_cited(self):
        retrieved = [make_source(letter, 2020) for letter in "ABC"]
        _, canonical = finalize_with_canonical(
            "Uses [1], [2], [3].\n\n### References\n\n[1] A\n[2] B\n[3] C", retrieved
        )
        self.assertEqual(len(canonical), 3)

    def test_canonical_with_out_of_order_partial_citation(self):
        # Five sources are retrieved; the text cites only [3] and [5].
        retrieved = [make_source(letter, 2020) for letter in "ABCDE"]
        _, canonical = finalize_with_canonical(
            "Claim [3] and finding [5].\n\n### References\n\n[3] C\n[5] E", retrieved
        )
        self.assertEqual(len(canonical), 2)
        # [3] ("C") is cited first, so it heads the canonical list.
        self.assertEqual(canonical[0]["title"], "C")
        self.assertEqual(canonical[1]["title"], "E")

    def test_body_references_match_canonical_count(self):
        """After finalization the distinct body citations are numbered 1..N,
        where N is the canonical source count."""
        retrieved = [make_source(letter, 2020) for letter in "ABCDEF"]  # six retrieved
        canon_text, canonical = finalize_with_canonical(
            "Body [1][2][3][4].\n\n### References\n\n[1] A\n[2] B\n[3] C\n[4] D", retrieved
        )
        distinct = sorted(set(_body_numbers(canon_text)))
        self.assertEqual(len(distinct), len(canonical))
        self.assertEqual(distinct, list(range(1, len(canonical) + 1)))
class BodyStrippingChainTests(unittest.TestCase):
    """The ``finalize_with_canonical`` -> ``parse_document`` -> ``.body`` chain
    that main.py uses to strip the References section before SSE."""

    @staticmethod
    def _stripped_body(text):
        """Return the whitespace-stripped body that ``parse_document`` extracts."""
        return parse_document(text).body.strip()

    def test_body_only_has_no_references_section(self):
        text = "Body [1] then [2].\n\n### References\n\n[1] Smith (2020).\n\n[2] Jones (2021)."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        canon_text, _ = finalize_with_canonical(text, sources)
        body = self._stripped_body(canon_text)
        self.assertNotIn("References", body)
        self.assertNotIn("[1] Smith", body)
        # Inline markers stay in the body even though the section is gone.
        self.assertIn("[1]", body)
        self.assertIn("[2]", body)

    def test_body_preserves_inline_citations(self):
        text = "Claim [2] supported by evidence [1].\n\n### References\n\n[1] Ref.\n\n[2] Ref."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        canon_text, _ = finalize_with_canonical(text, sources)
        body = self._stripped_body(canon_text)
        for marker in ("[1]", "[2]"):
            self.assertIn(marker, body)

    def test_canonical_sources_match_body_order(self):
        text = "First [3] then [1].\n\n### References\n\n[1] Ref1.\n\n[3] Ref3."
        sources = [make_source("A", 2020), make_source("B", 2021), make_source("C", 2022)]
        _, canonical = finalize_with_canonical(text, sources)
        # [3] ("C") appears first in the body, so it heads the canonical list.
        self.assertEqual(canonical[0]["title"], "C")
        self.assertEqual(canonical[1]["title"], "A")

    def test_no_sources_returns_body_without_citations(self):
        canon_text, canonical = finalize_with_canonical("Body [1] and [2].", [])
        self.assertNotIn("[1]", self._stripped_body(canon_text))
        self.assertEqual(canonical, [])

    def test_round_trip_idempotent(self):
        text = "Body [1][2].\n\n### References\n\n[1] Smith. (2020).\n\n[2] Jones. (2021)."
        sources = [make_source("A", 2020), make_source("B", 2021)]
        first_text, first_sources = finalize_with_canonical(text, sources)
        first_body = self._stripped_body(first_text)
        # Re-run on the stripped body plus a References section rebuilt from
        # the canonical sources; the body must come out unchanged.
        rebuilt_refs = "\n\n".join(
            f"[{number}] {source['title']}"
            for number, source in enumerate(first_sources, start=1)
        )
        second_text, _ = finalize_with_canonical(
            first_body + "\n\n### References\n\n" + rebuilt_refs, first_sources
        )
        self.assertEqual(first_body, self._stripped_body(second_text))
class AdjacentCitationSortTests(unittest.TestCase):
    """Runs of adjacent ``[N][M]`` markers come out in ascending order
    (academic convention)."""

    @staticmethod
    def _rendered_body(text, sources):
        """Finalize *text* and return the part before ``### References``."""
        return apply_finalization(text, sources).split("### References")[0]

    @staticmethod
    def _abc_sources():
        """Three 2020 sources titled SA, SB, SC."""
        return [make_source(f"S{letter}", 2020) for letter in "ABC"]

    def test_reversed_pair_sorted(self):
        body = self._rendered_body(
            "Some claim [2][1].", [make_source("A", 2020), make_source("B", 2021)]
        )
        self.assertIn("[1][2]", body)

    def test_triple_reversed_sorted(self):
        body = self._rendered_body("Evidence [3][1][2].", self._abc_sources())
        self.assertIn("[1][2][3]", body)

    def test_already_sorted_unchanged(self):
        body = self._rendered_body("Evidence [1][2][3].", self._abc_sources())
        self.assertIn("[1][2][3]", body)

    def test_duplicates_in_adjacent_deduped_and_sorted(self):
        body = self._rendered_body("Claim [3][1][3][2].", self._abc_sources())
        self.assertIn("[1][2][3]", body)

    def test_separate_groups_sorted_independently(self):
        body = self._rendered_body("First [2][1] and second [3][1].", self._abc_sources())
        # Renumbering maps orig 2 -> 1, orig 1 -> 2, orig 3 -> 3, so the first
        # group reads [1][2] and the second ([3][2]) sorts to [2][3].
        self.assertIn("[1][2]", body)
        self.assertIn("[2][3]", body)

    def test_single_citation_unchanged(self):
        body = self._rendered_body("Claim [1].", [make_source("A", 2020)])
        self.assertIn("[1]", body)

    def test_non_adjacent_not_merged(self):
        """``[2] some text [1]`` is not an adjacent run, so no sorting applies."""
        rendered = apply_finalization(
            "First [2] then some words [1].", [make_source("A", 2020), make_source("B", 2021)]
        )
        # Renumbering alone turns the first-seen [2] into [1] and [1] into [2].
        self.assertEqual(_body_numbers(rendered), [1, 2])
if __name__ == "__main__":
unittest.main()