| """Structured citation document model. |
| |
Replaces ad-hoc regex juggling with a clean parse → manipulate → render flow:
| |
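    doc = parse_document(messy_text)   # text → structured CitationDocument
    doc = doc.normalize(len(sources))  # drop phantom, renumber sequentially
    final_text = doc.render(sources)   # structured → canonical text

`apply_finalization(text, sources)` runs this flow end to end and also
rebuilds the References section from `sources`.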
| |
| The regex layer is *only* used at parse time; all manipulation happens on the |
| in-memory data structure, so edge cases are easy to reason about and test. |
| |
| Key invariants enforced by the model: |
| * Citation numbers are integers, never strings or ranges. |
| * After normalize(N), only numbers in 1..N appear, in first-appearance order. |
  * The references list is rebuilt from `sources` so it always matches the
    inline citations exactly — no phantom or missing entries.
| """ |
| from __future__ import annotations |
|
|
| import re |
| from dataclasses import dataclass, field |
| from typing import Iterable, Optional |
|
|
|
|
| |
|
|
| |
| |
| _CITATION_BLOCK_RE = re.compile(r"\[((?:\s*\d+\s*(?:[-,;]\s*\d+\s*)*))\]") |
|
|
| |
| |
| |
| _ESCAPED_BRACKET_RE = re.compile(r"\\(\[|\])") |
|
|
| |
| |
| _REFERENCES_HEADING_RE = re.compile( |
| r"(?im)^(?:#{1,6}\s*)?references?\s*:?\s*$" |
| ) |
|
|
| |
| _REFERENCE_ENTRY_RE = re.compile(r"^\s*\[(\d+)\]\s*(.+?)\s*$", re.S) |
|
|
| _HEADING_RE = re.compile(r"^#{1,6}\s") |
| _FENCE_RE = re.compile(r"^\s*```") |
|
|
|
|
| |
|
|
@dataclass
class Reference:
    """One numbered entry of a References section, i.e. ``[number] text``."""

    # Citation number taken from the "[N]" marker.
    number: int
    # Entry text rendered after the marker.
    text: str
|
|
|
|
@dataclass
class CitationDocument:
    """Parsed answer document with body and references separated.

    The body still contains the inline `[N]` markers. We don't substitute them
    out because that would lose positional information. Manipulation methods
    (`normalize`, `backfill_uncited_paragraphs`, `render`) operate on this
    representation. `normalize` and `backfill_uncited_paragraphs` return new
    documents instead of mutating `self`.
    """

    body: str
    references: list[Reference] = field(default_factory=list)

    # ------------------------------------------------------------------
    # Inspection
    # ------------------------------------------------------------------

    def cited_numbers(self) -> list[int]:
        """All citation numbers in body, in order, with duplicates.

        Grouped and range blocks are expanded by `_parse_citation_inner`, so
        `[1-3]` contributes 1, 2 and 3.
        """
        nums: list[int] = []
        for match in _CITATION_BLOCK_RE.finditer(self.body):
            nums.extend(_parse_citation_inner(match.group(1)))
        return nums

    def _unique_cited_numbers(self, source_count: Optional[int] = None) -> list[int]:
        """Deduplicated citation numbers in first-appearance order.

        When `source_count` is given, numbers outside 1..source_count are
        ignored entirely.
        """
        seen: set[int] = set()
        ordered: list[int] = []
        for num in self.cited_numbers():
            if num in seen:
                continue
            if source_count is not None and not 0 < num <= source_count:
                continue
            seen.add(num)
            ordered.append(num)
        return ordered

    def first_appearance_order(self) -> list[int]:
        """Citation numbers in order of first appearance, deduped."""
        return self._unique_cited_numbers()

    def first_appearance_order_within(self, source_count: int) -> list[int]:
        """First-appearance order, but only counting numbers within 1..source_count.

        The module-level `_first_appearance_order_within` is also assigned to
        this name after the class body; both behave identically.
        """
        return self._unique_cited_numbers(source_count)

    # ------------------------------------------------------------------
    # Transformations
    # ------------------------------------------------------------------

    def normalize(self, source_count: int) -> "CitationDocument":
        """Drop out-of-range citations and renumber to 1..K in first-appearance order.

        Args:
            source_count: Number of available sources. Valid citations are
                1..source_count.

        Returns:
            A new document whose body satisfies:
              * no citation number exceeds `source_count` or is <= 0;
              * numbers run 1..K, where K is the count of unique surviving
                citations;
              * the first cited source maps to [1], the next new one to [2],
                and so on;
              * grouped/range blocks are split into adjacent `[a][b]`
                markers, sorted ascending and de-duplicated within each run.
            `references` is always empty, and callers rebuild it.
        """
        if source_count <= 0:
            return CitationDocument(body=_strip_all_citations(self.body), references=[])

        ordered_old = self._unique_cited_numbers(source_count)
        if not ordered_old:
            return CitationDocument(body=_strip_all_citations(self.body), references=[])

        renumber = {old: new for new, old in enumerate(ordered_old, 1)}

        def _replace(match: re.Match) -> str:
            # Map each old number in this block to its new number, dropping
            # out-of-range numbers and in-block duplicates. An empty result
            # removes the whole block.
            new_nums: list[int] = []
            seen_in_block: set[int] = set()
            for old in _parse_citation_inner(match.group(1)):
                new_num = renumber.get(old)
                if new_num is not None and new_num not in seen_in_block:
                    seen_in_block.add(new_num)
                    new_nums.append(new_num)
            return "".join(f"[{n}]" for n in new_nums)

        new_body = _CITATION_BLOCK_RE.sub(_replace, self.body)
        new_body = _cleanup_punctuation_after_citation_removal(new_body)
        new_body = _sort_adjacent_citations(new_body)
        return CitationDocument(body=new_body, references=[])

    def backfill_uncited_paragraphs(self) -> "CitationDocument":
        """Add a citation to substantive paragraphs that lack one.

        Which paragraphs get a citation:
          * only paragraphs that `_is_substantive_paragraph` accepts and that
            contain no citation;
          * each borrows the first (up to) two citation numbers of the next
            paragraph that has citations, or of the nearest previous one when
            no later paragraph is cited.

        How they are added:
          * the borrowed numbers are appended as `[a][b]` in ascending order;
          * the order matches what `normalize` produces via
            `_sort_adjacent_citations`, so a later normalize pass does not
            re-sort them and change the text.

        Paragraph separators are preserved. The returned body is stripped and
        `references` is carried over unchanged.
        """
        parts = re.split(r"(\n\s*\n)", self.body)
        # With one capturing group, re.split alternates text and separator:
        # even slots are paragraphs, odd slots are the separators between them.
        paragraphs = parts[::2]
        para_cites = [_ordered_unique_in_block(p) for p in paragraphs]

        for idx, paragraph in enumerate(paragraphs):
            if para_cites[idx] or not _is_substantive_paragraph(paragraph):
                continue
            donor = self._nearest_cited(para_cites, idx)
            if not donor:
                # No paragraph carries a citation at all, so there is nothing
                # to borrow.
                continue
            borrowed = sorted(donor[:2])
            paragraphs[idx] = paragraph.rstrip() + " " + "".join(f"[{n}]" for n in borrowed)
            # Later uncited paragraphs may borrow from this one.
            para_cites[idx] = borrowed

        parts[::2] = paragraphs
        return CitationDocument(body="".join(parts).strip(), references=self.references)

    @staticmethod
    def _nearest_cited(para_cites: list[list[int]], idx: int) -> list[int]:
        """Find the citations to borrow for paragraph `idx`.

        Returns the citations of the next cited paragraph after `idx`, else
        those of the nearest cited paragraph before it, else an empty list.
        """
        for cites in para_cites[idx + 1:]:
            if cites:
                return cites
        for cites in reversed(para_cites[:idx]):
            if cites:
                return cites
        return []

    # ------------------------------------------------------------------
    # Output
    # ------------------------------------------------------------------

    def render(self, sources: list[dict], *, ref_builder=None) -> str:
        """Emit canonical text: body + freshly built References section.

        `sources` must be the underlying source list whose order matches the
        old citation numbers BEFORE this document was normalized. After
        normalize() the body uses 1..K, but the *meaning* of [1] is "the source
        that was first cited", so we need to know the mapping.

        Args:
            sources: Passed through unchanged to the reference builder.
            ref_builder: Callable `(references, sources) -> str`. Defaults to
                `_default_ref_builder`.

        Returns:
            "" for an empty body. Otherwise the stripped body, followed by a
            blank line and the builder's text when that text is non-empty.

        For convenience, callers can use `apply_finalization`, which combines
        normalize() and render() in a single call.
        """
        body = self.body.strip()
        if not body:
            return ""
        builder = ref_builder or _default_ref_builder
        ref_text = builder(self.references, sources)
        return f"{body}\n\n{ref_text}" if ref_text else body
|
|
|
|
| |
|
|
def parse_document(text: str) -> CitationDocument:
    """Parse a raw answer string into a CitationDocument.

    Steps:
      1. Un-escape markdown brackets (``\\[`` → ``[``).
      2. Split off the first References heading and everything after it.
      3. Parse that tail into `Reference` entries.

    The returned body is stripped and keeps its inline `[N]` markers. Empty
    or falsy input yields an empty document.
    """
    if not text:
        return CitationDocument(body="", references=[])

    unescaped = _ESCAPED_BRACKET_RE.sub(r"\1", text)
    body_part, refs_part = _split_references_section(unescaped)
    return CitationDocument(
        body=body_part.strip(),
        references=_parse_reference_entries(refs_part),
    )
|
|
|
|
def apply_finalization(text: str, sources: list[dict]) -> str:
    """End-to-end pipeline: parse → normalize → backfill → render with sources.

    Returns the final answer string: the normalized body followed by a
    References section rebuilt from `sources`. If nothing in range is cited,
    or no sources are given, every citation is stripped and only the body is
    returned.

    NOTE: This call interprets [N] in the body as an index into `sources`
    (1-based). After one call, the output body uses canonical 1..K numbering
    where 1 = first cited in body order. Calling this function a SECOND time
    on the output with the SAME `sources` list can produce a different result
    because the body's [1] now means "first in body order" rather than
    "sources[0]".

    To re-finalize already-canonical text, use `apply_finalization(out,
    canonical_sources)` where `canonical_sources` is the source list reordered
    to match the body's first-appearance order. See `finalize_with_canonical`.
    """
    if not text:
        return ""

    doc = parse_document(text)
    # Without sources every citation is out of range, so nothing survives.
    ordered_old = doc.first_appearance_order_within(len(sources)) if sources else []
    if not ordered_old:
        return _strip_all_citations(doc.body).strip()

    finalized = doc.normalize(len(sources)).backfill_uncited_paragraphs()

    # After normalize(), citation [k] refers to sources[ordered_old[k - 1] - 1],
    # so the reference list is rebuilt in exactly that order.
    finalized.references = [
        Reference(number=position, text=_render_apa(sources[old - 1]))
        for position, old in enumerate(ordered_old, start=1)
        if 0 < old <= len(sources)
    ]
    return finalized.render(sources, ref_builder=_render_reference_block)
|
|
|
|
def finalize_with_canonical(text: str, sources: list[dict]) -> tuple[str, list[dict]]:
    """Like `apply_finalization`, but also return the source list reordered to
    match the body's first-appearance order.

    Returns:
        ``(final_text, canonical_sources)``, where ``canonical_sources[k - 1]``
        is the source that citation ``[k]`` in ``final_text`` refers to.

    When the text is empty, no sources are given, or nothing in range is
    cited, the body cites nothing and ``canonical_sources`` is ``[]``.
    Previously an empty text returned the full input list instead.

    Feeding the returned ``(final_text, canonical_sources)`` back into this
    function should be a no-op. The body is already numbered 1..K in
    first-appearance order, so the renumbering is the identity.
    """
    if not text or not sources:
        return apply_finalization(text, sources), []

    ordered_old = parse_document(text).first_appearance_order_within(len(sources))
    canonical_sources = [sources[old - 1] for old in ordered_old]
    # If nothing in range is cited, apply_finalization strips every citation
    # and returns the bare body, and canonical_sources is [].
    return apply_finalization(text, sources), canonical_sources
|
|
|
|
| |
| |
def _first_appearance_order_within(self: CitationDocument, source_count: int) -> list[int]:
    """Unique citation numbers in 1..source_count, in order of first appearance.

    Out-of-range numbers are ignored entirely. Defined at module level and
    attached to `CitationDocument` just below.
    """
    in_range = (num for num in self.cited_numbers() if 0 < num <= source_count)
    # dicts keep insertion order, so fromkeys() de-duplicates while keeping
    # the first occurrence of each number in place.
    return list(dict.fromkeys(in_range))


CitationDocument.first_appearance_order_within = _first_appearance_order_within
|
|
|
|
| |
|
|
| def _parse_citation_inner(block: str) -> list[int]: |
| """Split a citation block's inner text into integers. |
| |
| Handles ranges (`1-3`), comma/semicolon separated lists (`1, 2`), and |
| plain single numbers (`1`). Bogus tokens are ignored. |
| """ |
| nums: list[int] = [] |
| for part in re.split(r"[;,]", block): |
| part = part.strip() |
| if not part: |
| continue |
| range_match = re.fullmatch(r"(\d+)\s*-\s*(\d+)", part) |
| if range_match: |
| start, end = int(range_match.group(1)), int(range_match.group(2)) |
| if start <= end and (end - start) <= 20: |
| nums.extend(range(start, end + 1)) |
| else: |
| nums.extend([start, end]) |
| continue |
| if part.isdigit(): |
| nums.append(int(part)) |
| return nums |
|
|
|
|
def _ordered_unique_in_block(text: str) -> list[int]:
    """Every distinct citation number in *text* (e.g. one paragraph), in the
    order each first appears."""
    every_number = (
        num
        for block in _CITATION_BLOCK_RE.finditer(text)
        for num in _parse_citation_inner(block.group(1))
    )
    # dict.fromkeys keeps the first occurrence of each key in insertion order.
    return list(dict.fromkeys(every_number))
|
|
|
|
def _strip_all_citations(body: str) -> str:
    """Delete every bracketed citation block from *body*, then tidy the gaps.

    Used when no citation can survive (no sources, or nothing cited in range).
    """
    return _cleanup_punctuation_after_citation_removal(_CITATION_BLOCK_RE.sub("", body))
|
|
|
|
| |
| _ADJACENT_CITATION_RUN_RE = re.compile(r"(\[\d+\])(?:\s*\[\d+\])+") |
|
|
|
|
def _sort_adjacent_citations(text: str) -> str:
    """Rewrite each run of adjacent `[N]` markers in ascending, de-duplicated order.

    Academic convention lists adjacent citations in ascending order, e.g.
    `[1][3][2]` → `[1][2][3]`. A run is whatever `_ADJACENT_CITATION_RUN_RE`
    matches. The rewritten run has no whitespace between markers, and a number
    repeated inside a run is kept once.
    """
    def _ascending(run: re.Match) -> str:
        unique_nums = {int(num) for num in re.findall(r"\[(\d+)\]", run.group(0))}
        return "".join(f"[{num}]" for num in sorted(unique_nums))

    return _ADJACENT_CITATION_RUN_RE.sub(_ascending, text)
|
|
|
|
| def _cleanup_punctuation_after_citation_removal(text: str) -> str: |
| """Tidy up double spaces and stranded punctuation left behind.""" |
| text = re.sub(r"[ \t]{2,}", " ", text) |
| text = re.sub(r"\s+([,.;:])", r"\1", text) |
| text = re.sub(r"\n{3,}", "\n\n", text) |
| |
| text = re.sub(r"[ \t]+\n", "\n", text) |
| return text |
|
|
|
|
def _split_references_section(text: str) -> tuple[str, str]:
    """Split *text* at the first References heading.

    Returns ``(body, refs)``:
      * ``body`` is the text before the heading, right-stripped;
      * ``refs`` is the text after the heading line, stripped.

    Without a heading the whole text (right-stripped) is the body and
    ``refs`` is "".
    """
    heading = _REFERENCES_HEADING_RE.search(text)
    if heading is None:
        return text.rstrip(), ""
    start, end = heading.span()
    return text[:start].rstrip(), text[end:].strip()
|
|
|
|
def _parse_reference_entries(refs_block: str) -> list[Reference]:
    """Parse the text after a References heading into `Reference` entries.

    Chunking rules:
      * a chunk ends at a blank line;
      * a new chunk also starts at any line whose first non-blank text is a
        ``[N]`` marker, so ``"[1] A\\n[2] B"`` yields two entries, just like
        blank-line-separated lists;
      * a marker-less line directly below an entry stays attached to it;
      * chunks that do not start with ``[N]`` (e.g. a closing remark) are
        skipped.

    Returns an empty list for empty or whitespace-only input.
    """
    if not refs_block or not refs_block.strip():
        return []
    entries: list[Reference] = []
    # The lookahead in the second alternative splits before a marker without
    # consuming it, so the marker stays at the start of the next chunk.
    for chunk in re.split(r"\n\s*\n|\n\s*(?=\[\d+\])", refs_block.strip()):
        match = _REFERENCE_ENTRY_RE.match(chunk.strip())
        if match:
            entries.append(Reference(number=int(match.group(1)), text=match.group(2).strip()))
    return entries
|
|
|
|
def _is_substantive_paragraph(paragraph: str) -> bool:
    """Return True if *paragraph* is prose long enough to deserve a citation.

    Excluded:
      * blank text;
      * markdown headings and fenced code blocks;
      * blockquotes (``>``) and tables (lines starting with ``|``);
      * bullet list items (``-``, ``*`` or ``+`` followed by whitespace);
      * ordered list items (``1.`` / ``1)`` followed by whitespace).

    Everything else needs at least 8 word tokens. A paragraph that merely
    starts with emphasis (``**Key point:** …``) is prose, not a list item, so
    it is not excluded.
    """
    stripped = paragraph.strip()
    if not stripped:
        return False
    if _HEADING_RE.match(stripped) or _FENCE_RE.match(stripped):
        return False
    if stripped.startswith((">", "|")):
        return False
    # CommonMark list markers require whitespace (or end of line) after them;
    # this keeps "**bold**", "-5%" and "3.5" from being mistaken for lists.
    if re.match(r"(?:[-*+]|\d+[.)])(?:\s|$)", stripped):
        return False
    return len(re.findall(r"\b\w+\b", stripped)) >= 8
|
|
|
|
| |
|
|
| def _render_apa(source: dict) -> str: |
| """Render a single source as an APA-style reference string.""" |
| authors = source.get("authors", "") |
| year = source.get("year") or source.get("publication_year") or "n.d." |
| title = source.get("title", "Untitled") |
| journal = source.get("journal", "") |
| doi = (source.get("doi") or "").replace("https://doi.org/", "") |
|
|
| parts: list[str] = [] |
| if authors: |
| parts.append(f"{authors}.") |
| parts.append(f"*{title}*") |
| parts.append(f"({year}).") |
| if journal: |
| parts.append(f"*{journal}*.") |
| if doi: |
| parts.append(f"DOI: {doi}") |
| return " ".join(parts) |
|
|
|
|
| def _render_reference_block(refs: Iterable[Reference], sources: list[dict]) -> str: |
| lines = [f"[{r.number}] {r.text}" for r in refs] |
| if not lines: |
| return "" |
| return "### References\n\n" + "\n\n".join(lines) |
|
|
|
|
def _default_ref_builder(refs: Iterable[Reference], sources: list[dict]) -> str:
    """Reference builder used by `CitationDocument.render` when none is given.

    Delegates directly to `_render_reference_block`.
    """
    return _render_reference_block(refs=refs, sources=sources)
|
|
|
|
| |
|
|
def has_phantom_citations(text: str, source_count: int) -> bool:
    """Report whether any inline citation falls outside 1..source_count."""
    cited = parse_document(text).cited_numbers()
    return not all(0 < num <= source_count for num in cited)
|
|
|
|
def has_grouped_or_range_syntax(text: str) -> bool:
    """Report whether the body uses grouped or range citations.

    Forms like ``[1, 2]``, ``[1; 2]`` or ``[1-2]`` count; split markers like
    ``[1][2]`` do not.
    """
    body = parse_document(text).body
    return any(
        any(sep in block.group(1) for sep in (",", ";", "-"))
        for block in _CITATION_BLOCK_RE.finditer(body)
    )
|
|
|
|
def has_escaped_brackets(text: str) -> bool:
    """Report whether *text* contains a markdown-escaped bracket (\\[ or \\]).

    Falsy input (``None`` or "") returns False.
    """
    if not text:
        return False
    return _ESCAPED_BRACKET_RE.search(text) is not None
|
|
|
|
def is_sequential_in_first_appearance_order(text: str) -> bool:
    """Check that citations first appear as 1, 2, 3, … with no gaps or jumps."""
    first_seen = parse_document(text).first_appearance_order()
    return all(num == expected for expected, num in enumerate(first_seen, start=1))
|
|
|
|
def body_and_references_match(text: str) -> bool:
    """Check that the body cites exactly the numbers listed under References."""
    parsed = parse_document(text)
    return set(parsed.cited_numbers()) == {ref.number for ref in parsed.references}
|
|