# crcs-live / citations.py
# Nipun's picture
# v0.25.0: Align with HTML artifact — t=0 counting, Phase 2 tracking, post-stability L1/S1
# 7b80f8b
"""Structured citation document model.
Replaces ad-hoc regex juggling with a clean parse → manipulate → render flow:
    doc = parse_document(messy_text)   # text → structured CitationDocument
    doc = doc.normalize(len(sources))  # drop phantom, renumber sequentially
    final_text = doc.render(sources)   # structured → canonical text
Regexes locate and rewrite the inline citation markers; the document is parsed
once into a CitationDocument (body + structured references) and all further
manipulation goes through its methods, so edge cases are easy to reason about
and test.
Key invariants enforced by the model:
* Citation numbers are integers, never strings or ranges.
* After normalize(N), only numbers in 1..K appear (K = number of distinct
  in-range citations, K <= N), in first-appearance order.
* The references list is rebuilt from `sources` (see apply_finalization) so it
  always matches the inline citations exactly — no phantom or missing entries.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Iterable, Optional
# ── Regex ───────────────────────────────────────────────────────────────────
# Compiled once at import time and shared by parsing, normalization,
# backfilling and the diagnostics at the bottom of the module.

# Matches ONE citation block: [1], [1,2], [1-3] or [1; 2]. Adjacent blocks such
# as [1][2] are matched separately, one per finditer()/sub() hit. Group 1 is the
# inner text, which _parse_citation_inner() splits into ints. Note that any
# bracketed integer (e.g. "[2020]") matches as well.
_CITATION_BLOCK_RE = re.compile(r"\[((?:\s*\d+\s*(?:[-,;]\s*\d+\s*)*))\]")
# Markdown-escaped brackets that some LLMs emit (\[1\]\[2\]). Renderers strip
# the escapes, so the displayed text shows real citations that would otherwise
# bypass the regex above. parse_document() unescapes them eagerly; group 1 is
# the bracket without its backslash.
_ESCAPED_BRACKET_RE = re.compile(r"\\(\[|\])")
# References section heading: case-insensitive "Reference"/"References" on its
# own line, with an optional markdown "#".."######" prefix and optional colon.
# NOTE(review): \s can also match newlines, so a lone "#" line directly above a
# "References" line is matched as one heading.
_REFERENCES_HEADING_RE = re.compile(
    r"(?im)^(?:#{1,6}\s*)?references?\s*:?\s*$"
)
# One reference entry: "[N]" then its text. re.S lets the text span lines.
_REFERENCE_ENTRY_RE = re.compile(r"^\s*\[(\d+)\]\s*(.+?)\s*$", re.S)
# ATX markdown heading: 1-6 "#" characters followed by whitespace.
_HEADING_RE = re.compile(r"^#{1,6}\s")
# Opening or closing line of a ``` fenced code block (leading whitespace allowed).
_FENCE_RE = re.compile(r"^\s*```")
# ── Data classes ────────────────────────────────────────────────────────────
@dataclass
class Reference:
    """A single entry from the references section."""

    # The bracketed number, e.g. 3 for an entry written "[3] ...".
    number: int
    # Entry text after the "[N]" marker (parsed from input), or the rendered
    # APA string when built by apply_finalization().
    text: str
@dataclass
class CitationDocument:
    """Parsed answer document with body and references separated.

    The body keeps its inline ``[N]`` markers; they are not substituted out
    because that would lose positional information. ``normalize`` and
    ``backfill_uncited_paragraphs`` return new documents and ``render``
    returns a string; none of them reassign ``self.body``.
    """

    # Answer text without the References section; still contains [N] markers.
    body: str
    # Structured References entries (parsed, or assigned by apply_finalization).
    references: list[Reference] = field(default_factory=list)

    # ── Citation extraction ────────────────────────────────────────────

    def cited_numbers(self) -> list[int]:
        """All citation numbers in the body, in order, with duplicates.

        Grouped and range blocks are expanded, so ``[1-3]`` yields 1, 2, 3.
        """
        nums: list[int] = []
        for match in _CITATION_BLOCK_RE.finditer(self.body):
            nums.extend(_parse_citation_inner(match.group(1)))
        return nums

    def first_appearance_order(self) -> list[int]:
        """Citation numbers in order of first appearance, deduplicated."""
        seen: set[int] = set()
        ordered: list[int] = []
        for num in self.cited_numbers():
            if num not in seen:
                seen.add(num)
                ordered.append(num)
        return ordered

    # ── Manipulation ───────────────────────────────────────────────────

    def normalize(self, source_count: int) -> "CitationDocument":
        """Drop out-of-range citations and renumber to 1..K in first-appearance order.

        After normalize:
          * No citation number exceeds ``source_count`` or is <= 0.
          * Numbers run 1..K where K is the count of unique surviving citations.
          * The first cited source maps to [1], the next new one to [2], etc.
          * Grouped/range blocks (``[1, 2]``, ``[1-3]``) become adjacent
            single-number blocks, deduplicated and sorted ascending per run.

        A block whose numbers are all out of range is removed and the
        surrounding whitespace/punctuation tidied. If no valid citation exists
        at all, every citation block is stripped. The returned document has an
        empty ``references`` list.
        """
        if source_count <= 0:
            return CitationDocument(body=_strip_all_citations(self.body), references=[])

        # Unique in-range old numbers, in the order they are first cited.
        ordered_old: list[int] = []
        seen: set[int] = set()
        for num in self.cited_numbers():
            if 0 < num <= source_count and num not in seen:
                seen.add(num)
                ordered_old.append(num)
        if not ordered_old:
            return CitationDocument(body=_strip_all_citations(self.body), references=[])

        # old_num → new_num map (new numbers start at 1, so they are truthy).
        renumber = {old: new for new, old in enumerate(ordered_old, 1)}

        def _replace(match: re.Match) -> str:
            # Rewrite one block as adjacent single-number blocks, dropping
            # out-of-range numbers and duplicates within the block.
            new_nums: list[int] = []
            seen_in_block: set[int] = set()
            for old in _parse_citation_inner(match.group(1)):
                new_num = renumber.get(old)
                if new_num and new_num not in seen_in_block:
                    seen_in_block.add(new_num)
                    new_nums.append(new_num)
            if not new_nums:
                return ""
            return "".join(f"[{n}]" for n in new_nums)

        new_body = _CITATION_BLOCK_RE.sub(_replace, self.body)
        new_body = _cleanup_punctuation_after_citation_removal(new_body)
        new_body = _sort_adjacent_citations(new_body)
        return CitationDocument(body=new_body, references=[])

    def backfill_uncited_paragraphs(self) -> "CitationDocument":
        """Add a citation to substantive paragraphs that lack one.

        Paragraphs are separated by blank lines. An uncited paragraph accepted
        by ``_is_substantive_paragraph`` borrows up to two citation numbers
        from the nearest *following* paragraph that has citations, falling
        back to the nearest *preceding* one. Borrowed numbers are appended in
        ascending order, matching the convention of ``_sort_adjacent_citations``.

        Paragraphs that start inside a fenced code block, or contain a fence
        line, are never modified: appending ``[N]`` there would corrupt code.

        Returns a new document with a stripped body; ``references`` is carried
        over unchanged.
        """
        parts = re.split(r"(\n\s*\n)", self.body)
        # With one capture group, re.split alternates text / separator, so
        # there is always exactly one more paragraph than separator.
        paragraphs = parts[::2]
        separators = parts[1::2]
        para_cites = [_ordered_unique_in_block(p) for p in paragraphs]

        # Flag paragraphs that touch a ``` fence. Fence state carries across
        # blank lines, so a code block split by blank lines is tracked as one.
        in_fence = False
        fenced: list[bool] = []
        for paragraph in paragraphs:
            touches_fence = in_fence
            for line in paragraph.splitlines():
                if _FENCE_RE.match(line):
                    in_fence = not in_fence
                    touches_fence = True
            fenced.append(touches_fence)

        for idx, paragraph in enumerate(paragraphs):
            if para_cites[idx] or fenced[idx]:
                continue
            if not _is_substantive_paragraph(paragraph):
                continue
            # Prefer the next cited paragraph, then the closest earlier one
            # (which may itself have been backfilled in an earlier iteration).
            donor = next((cites for cites in para_cites[idx + 1:] if cites), None)
            if donor is None:
                donor = next((cites for cites in reversed(para_cites[:idx]) if cites), None)
            if donor is None:
                continue
            borrowed = sorted(donor[:2])
            paragraphs[idx] = paragraph.rstrip() + " " + "".join(f"[{n}]" for n in borrowed)
            para_cites[idx] = borrowed

        rebuilt = [paragraphs[0]]
        for separator, paragraph in zip(separators, paragraphs[1:]):
            rebuilt.append(separator)
            rebuilt.append(paragraph)
        return CitationDocument(body="".join(rebuilt).strip(), references=self.references)

    # ── Rendering ──────────────────────────────────────────────────────

    def render(self, sources: list[dict], *, ref_builder=None) -> str:
        """Emit canonical text: the body followed by a References section.

        ``ref_builder(self.references, sources)`` produces the references
        text; it defaults to ``_default_ref_builder``, which formats
        ``self.references`` as-is (custom builders may also use ``sources``).
        Callers must therefore populate ``self.references`` with entries whose
        numbers match the body's ``[N]`` markers, as ``apply_finalization``
        does.

        Returns "" for a blank body, and just the stripped body when the
        builder returns an empty string.
        """
        body = self.body.strip()
        if not body:
            return ""
        builder = ref_builder or _default_ref_builder
        ref_text = builder(self.references, sources)
        if not ref_text:
            return body
        return body + "\n\n" + ref_text
# ── Module-level helpers ────────────────────────────────────────────────────
def parse_document(text: str) -> CitationDocument:
    """Split raw answer text into a CitationDocument.

    Markdown-escaped brackets (``\\[1\\]``) are unescaped first so they are
    seen as citations. Everything after the first References heading is
    parsed into ``Reference`` entries; the rest (stripped) becomes the body,
    which keeps its inline ``[N]`` markers. Empty input yields an empty
    document.
    """
    if not text:
        return CitationDocument(body="", references=[])
    unescaped = _ESCAPED_BRACKET_RE.sub(r"\1", text)
    body_text, refs_text = _split_references_section(unescaped)
    parsed_refs = _parse_reference_entries(refs_text)
    return CitationDocument(body=body_text.strip(), references=parsed_refs)
def apply_finalization(text: str, sources: list[dict]) -> str:
    """End-to-end pipeline: parse → normalize → backfill → render.

    Returns the final answer text (body plus a freshly rendered References
    section) ready for the frontend; this is the main production entry point.

    ``[N]`` in the input body is read as a 1-based index into ``sources``.
    Citations outside 1..len(sources) are dropped; if none survive (or
    ``sources`` is empty) every citation is stripped and only the body is
    returned.

    NOTE: the output renumbers citations to 1..K in first-appearance order,
    so ``[1]`` then means "first source cited in the body", not
    ``sources[0]``. Running this again on its own output with the SAME
    ``sources`` therefore gives a different result. To re-finalize, pass the
    sources reordered to the body's first-appearance order, which is what
    ``finalize_with_canonical`` returns alongside the text.
    """
    if not text:
        return ""
    doc = parse_document(text)
    # Without sources, or without any in-range citation, nothing can be
    # referenced: drop every citation block and return the body alone.
    ordered_old = doc.first_appearance_order_within(len(sources)) if sources else []
    if not ordered_old:
        return _strip_all_citations(doc.body).strip()

    finalized = doc.normalize(len(sources)).backfill_uncited_paragraphs()
    # normalize() numbers slots in the same first-appearance order, so slot k
    # in the new body corresponds to ordered_old[k - 1] in `sources`.
    finalized.references = [
        Reference(number=slot, text=_render_apa(sources[old - 1]))
        for slot, old in enumerate(ordered_old, start=1)
    ]
    return finalized.render(sources, ref_builder=_render_reference_block)
def finalize_with_canonical(text: str, sources: list[dict]) -> tuple[str, list[dict]]:
    """Finalize like ``apply_finalization`` and also return canonical sources.

    Returns ``(final_text, canonical_sources)`` where ``canonical_sources``
    lists only the cited sources, reordered to the body's first-appearance
    order, so ``canonical_sources[k - 1]`` is the source behind ``[k]`` in
    ``final_text``. When nothing is cited (empty text, no sources, or no
    in-range citation) the list is empty.

    Feeding the returned pair back in is intended to be a no-op: the body
    already uses 1..K in first-appearance order, so renumbering against
    ``canonical_sources`` is the identity mapping.
    """
    if not text:
        # Empty input cites nothing, consistent with the no-citation branch.
        return "", []
    if not sources:
        return apply_finalization(text, sources), []
    doc = parse_document(text)
    ordered_old = doc.first_appearance_order_within(len(sources))
    if not ordered_old:
        return _strip_all_citations(doc.body).strip(), []
    # Reorder sources to match the body's first-appearance order.
    canonical_sources = [sources[old - 1] for old in ordered_old]
    return apply_finalization(text, sources), canonical_sources
# Convenience monkey-patch: add to CitationDocument so callers don't need a
# separate helper for the source-count-aware order.
def _first_appearance_order_within(self: CitationDocument, source_count: int) -> list[int]:
"""First-appearance order, but only counting numbers within 1..source_count."""
seen: set[int] = set()
ordered: list[int] = []
for num in self.cited_numbers():
if 0 < num <= source_count and num not in seen:
seen.add(num)
ordered.append(num)
return ordered
# Attach the helper as a regular method. It applies the same in-range filter as
# CitationDocument.normalize(), which is what lets apply_finalization() map new
# slot k back to ordered_old[k - 1].
CitationDocument.first_appearance_order_within = _first_appearance_order_within  # type: ignore[attr-defined]
# ── Internal helpers ────────────────────────────────────────────────────────
def _parse_citation_inner(block: str) -> list[int]:
"""Split a citation block's inner text into integers.
Handles ranges (`1-3`), comma/semicolon separated lists (`1, 2`), and
plain single numbers (`1`). Bogus tokens are ignored.
"""
nums: list[int] = []
for part in re.split(r"[;,]", block):
part = part.strip()
if not part:
continue
range_match = re.fullmatch(r"(\d+)\s*-\s*(\d+)", part)
if range_match:
start, end = int(range_match.group(1)), int(range_match.group(2))
if start <= end and (end - start) <= 20:
nums.extend(range(start, end + 1))
else:
nums.extend([start, end])
continue
if part.isdigit():
nums.append(int(part))
return nums
def _ordered_unique_in_block(text: str) -> list[int]:
    """Distinct citation numbers found in ``text``, in first-appearance order."""
    every_number = (
        num
        for block in _CITATION_BLOCK_RE.finditer(text)
        for num in _parse_citation_inner(block.group(1))
    )
    # dict.fromkeys dedupes while preserving the order numbers were first seen.
    return list(dict.fromkeys(every_number))
def _strip_all_citations(body: str) -> str:
    """Delete every citation block from ``body`` and tidy the leftover gaps."""
    without_blocks = _CITATION_BLOCK_RE.sub("", body)
    return _cleanup_punctuation_after_citation_removal(without_blocks)
# Matches a run of adjacent [N] blocks like [2][1][3] (possibly with whitespace).
_ADJACENT_CITATION_RUN_RE = re.compile(r"(\[\d+\])(?:\s*\[\d+\])+")
def _sort_adjacent_citations(text: str) -> str:
"""Sort runs of adjacent [N][M] blocks into ascending order.
Academic convention puts adjacent citations in ascending order:
[1][3][2] β†’ [1][2][3]. Each run of adjacent single-number brackets
is detected as a unit and the numbers within it are sorted.
"""
def _sort_run(match: re.Match) -> str:
nums = [int(m.group(1)) for m in re.finditer(r"\[(\d+)\]", match.group(0))]
return "".join(f"[{n}]" for n in sorted(set(nums)))
return _ADJACENT_CITATION_RUN_RE.sub(_sort_run, text)
def _cleanup_punctuation_after_citation_removal(text: str) -> str:
"""Tidy up double spaces and stranded punctuation left behind."""
text = re.sub(r"[ \t]{2,}", " ", text)
text = re.sub(r"\s+([,.;:])", r"\1", text)
text = re.sub(r"\n{3,}", "\n\n", text)
# Drop spaces before line endings introduced by removing a citation.
text = re.sub(r"[ \t]+\n", "\n", text)
return text
def _split_references_section(text: str) -> tuple[str, str]:
    """Split ``text`` at the first References heading.

    Returns ``(body, refs)``: the body is right-stripped, and refs (the text
    after the heading) is fully stripped. Without a heading, refs is "".
    """
    heading = _REFERENCES_HEADING_RE.search(text)
    if heading is None:
        return text.rstrip(), ""
    start, end = heading.span()
    return text[:start].rstrip(), text[end:].strip()
def _parse_reference_entries(refs_block: str) -> list[Reference]:
    """Parse the text of a References section into ``Reference`` entries.

    An entry starts with ``[N]``. Entries may be separated by blank lines or
    simply start on the next line, so "[1] A\\n[2] B" yields two entries.
    Following lines without a leading ``[N]`` stay attached to the current
    entry. Chunks that do not start with ``[N]`` are ignored.
    """
    if not refs_block:
        return []
    # Split on a blank line, OR on a newline that is followed by a "[N]" line.
    # The lookahead keeps the "[N]" marker at the start of the next chunk.
    chunks = re.split(r"\n\s*\n|\n(?=[ \t]*\[\d+\])", refs_block.strip())
    entries: list[Reference] = []
    for chunk in chunks:
        match = _REFERENCE_ENTRY_RE.match(chunk.strip())
        if match:
            entries.append(Reference(number=int(match.group(1)), text=match.group(2).strip()))
    return entries
def _is_substantive_paragraph(paragraph: str) -> bool:
    """Return True for a prose paragraph of at least 8 words.

    Rejected: blank text, ATX headings, blockquotes (``>``), paragraphs
    starting with a code fence, and list items. A list item is a bullet
    (``-``, ``*``, ``+``) or ordered marker (``1.``, ``1)``) *followed by
    whitespace*. Prose that merely starts with emphasis (``**Note:** ...``)
    or a number ("1.5 million ...") still counts as a paragraph.
    """
    stripped = paragraph.strip()
    if not stripped:
        return False
    if _HEADING_RE.match(stripped) or _FENCE_RE.match(stripped):
        return False
    if stripped.startswith(">"):
        return False
    if re.match(r"(?:[-*+]|\d+[.)])\s", stripped):
        return False
    return len(re.findall(r"\b\w+\b", stripped)) >= 8
# ── Reference rendering (APA-ish) ───────────────────────────────────────────
def _render_apa(source: dict) -> str:
"""Render a single source as an APA-style reference string."""
authors = source.get("authors", "")
year = source.get("year") or source.get("publication_year") or "n.d."
title = source.get("title", "Untitled")
journal = source.get("journal", "")
doi = (source.get("doi") or "").replace("https://doi.org/", "")
parts: list[str] = []
if authors:
parts.append(f"{authors}.")
parts.append(f"*{title}*")
parts.append(f"({year}).")
if journal:
parts.append(f"*{journal}*.")
if doi:
parts.append(f"DOI: {doi}")
return " ".join(parts)
def _render_reference_block(refs: Iterable[Reference], sources: list[dict]) -> str:
lines = [f"[{r.number}] {r.text}" for r in refs]
if not lines:
return ""
return "### References\n\n" + "\n\n".join(lines)
def _default_ref_builder(refs: Iterable[Reference], sources: list[dict]) -> str:
    """Reference builder used by ``CitationDocument.render`` when none is given."""
    rendered = _render_reference_block(refs, sources)
    return rendered
# ── Diagnostics (used by tests + audit) ─────────────────────────────────────
def has_phantom_citations(text: str, source_count: int) -> bool:
    """True if any body citation number falls outside 1..source_count."""
    numbers = parse_document(text).cited_numbers()
    return not all(0 < n <= source_count for n in numbers)
def has_grouped_or_range_syntax(text: str) -> bool:
    """True if any body citation block is grouped or a range.

    ``[1, 2]``, ``[1; 2]`` and ``[1-2]`` count; adjacent single blocks such
    as ``[1][2]`` do not.
    """
    body = parse_document(text).body
    return any(
        any(separator in block.group(1) for separator in (",", ";", "-"))
        for block in _CITATION_BLOCK_RE.finditer(body)
    )
def has_escaped_brackets(text: str) -> bool:
    """True if the raw text contains a markdown-escaped bracket (\\[ or \\]).

    Falsy input (None or "") returns False.
    """
    if not text:
        return False
    return _ESCAPED_BRACKET_RE.search(text) is not None
def is_sequential_in_first_appearance_order(text: str) -> bool:
    """True if citations first appear as exactly 1, 2, ..., N (no gaps or jumps)."""
    order = parse_document(text).first_appearance_order()
    return all(num == position for position, num in enumerate(order, start=1))
def body_and_references_match(text: str) -> bool:
    """True if the body cites exactly the numbers listed in the References section."""
    doc = parse_document(text)
    listed = {entry.number for entry in doc.references}
    return listed == set(doc.cited_numbers())