# Source path: chatbot-rag-fi/src/services/citations.py
# Page-header residue from the hosting site (not part of the module):
#   "ABAO77's picture" / "Upload 147 files" / 0df80b4 (verified)
from __future__ import annotations
import re
from dataclasses import dataclass
from src.utils.agent_utils import DOC_REF_RE
from src.schemas import ChatCitation, ChatCitationSegment, ChatMessageSegment, ChatTextSegment
from src.services.blog import citation_repository
# Loose matcher for anything that looks like a ``<doc-ref ...>`` tag: the tag
# name, then any run of non-``>`` characters, then an optional ``/`` and the
# closing ``>``. It is deliberately more permissive than DOC_REF_RE so that
# malformed or non-self-closing tags are still located and removed from the
# visible text; each candidate is re-checked with DOC_REF_RE.fullmatch before
# it is treated as a real citation.
DOC_REF_CANDIDATE_RE = re.compile(r"<doc-ref\b[^>]*\/?>")
@dataclass(slots=True)
class ParsedCitationResponse:
    """Result of splitting a raw reply into text and citation segments.

    Instances are built by ``parse_citation_segments``.
    """

    # Visible text with every ``<doc-ref>`` candidate tag removed, stripped of
    # leading/trailing whitespace.
    content: str
    # Text and citation segments in the order they occur in the raw reply.
    segments: list[ChatMessageSegment]
    # Citations returned by ``citation_repository.resolve`` for the allowed
    # document ids referenced in the reply (ids passed in first-seen order).
    citations: list[ChatCitation]
class CitationTagStreamFilter:
    """Incrementally remove ``<doc-ref ...>`` tags from a stream of text deltas.

    Feed each chunk to :meth:`feed` and forward whatever it returns to the
    client; call :meth:`flush` once the stream ends to obtain any text that
    was being held back.

    Attributes:
        pending: Text received but not yet classified as visible or as part
            of a tag (e.g. a trailing ``"<do"`` that may become ``<doc-ref``).
        inside_tag: True while a ``<doc-ref`` tag has been opened and its
            closing ``>`` has not yet arrived.
    """

    _TAG_OPEN = "<doc-ref"
    # A tag ends at the first ``>``. This mirrors DOC_REF_CANDIDATE_RE
    # (``[^>]*\/?>``), which accepts both ``<doc-ref .../>`` and
    # ``<doc-ref ...>``. Waiting specifically for ``/>`` would swallow all text
    # following a non-self-closed tag.
    _TAG_CLOSE = ">"

    def __init__(self) -> None:
        self.pending = ""
        self.inside_tag = False

    def feed(self, delta: str) -> str:
        """Consume one streamed chunk and return the part that is safe to show.

        Args:
            delta: The next piece of raw text from the stream.

        Returns:
            Visible text with complete ``<doc-ref>`` tags removed. Text that
            could still turn out to be (part of) a tag is retained in
            ``self.pending`` and is not returned yet.
        """
        self.pending += delta
        visible_parts: list[str] = []

        while self.pending:
            if self.inside_tag:
                end_index = self.pending.find(self._TAG_CLOSE)
                if end_index == -1:
                    # Tag not finished yet; wait for more input.
                    break
                self.pending = self.pending[end_index + len(self._TAG_CLOSE) :]
                self.inside_tag = False
                continue

            lt_index = self.pending.find("<")
            if lt_index == -1:
                # No possible tag start: everything is visible.
                visible_parts.append(self.pending)
                self.pending = ""
                break

            if lt_index > 0:
                # Emit the text before the "<" and re-examine from the "<".
                visible_parts.append(self.pending[:lt_index])
                self.pending = self.pending[lt_index:]
                continue

            if self.pending.startswith(self._TAG_OPEN):
                self.inside_tag = True
                continue

            if self._TAG_OPEN.startswith(self.pending):
                # A strict prefix of "<doc-ref": cannot decide yet.
                break

            # A "<" that does not begin a doc-ref tag is ordinary text.
            visible_parts.append(self.pending[0])
            self.pending = self.pending[1:]

        return "".join(visible_parts)

    def flush(self) -> str:
        """Return held-back text at end of stream and reset the filter.

        An unterminated ``<doc-ref`` tag is discarded; a held-back partial
        prefix such as ``"<do"`` is returned as ordinary text.
        """
        leftover = "" if self.inside_tag else self.pending
        self.pending = ""
        self.inside_tag = False
        return leftover
def parse_citation_segments(raw_text: str, *, allowed_document_ids: set[str]) -> ParsedCitationResponse:
    """Split a raw reply into text and citation segments.

    Args:
        raw_text: Reply text that may contain ``<doc-ref ...>`` tags.
        allowed_document_ids: Document ids that may be cited; tags that
            reference any other id are dropped from the output.

    Returns:
        A ``ParsedCitationResponse`` whose ``content`` is the tag-free text
        (stripped), whose ``segments`` interleave text and citation segments
        in reading order, and whose ``citations`` are what
        ``citation_repository.resolve`` returned for the allowed ids.
    """
    # dict.fromkeys de-duplicates in linear time while keeping first-seen order.
    ordered_doc_ids = list(
        dict.fromkeys(
            match.group("id")
            for match in DOC_REF_RE.finditer(raw_text)
            if match.group("id") in allowed_document_ids
        )
    )

    citations = citation_repository.resolve(ordered_doc_ids)
    citations_by_id = {citation.document_id: citation for citation in citations}

    segments: list[ChatMessageSegment] = []
    visible_parts: list[str] = []
    last_end = 0

    for match in DOC_REF_CANDIDATE_RE.finditer(raw_text):
        if match.start() > last_end:
            text = raw_text[last_end : match.start()]
            segments.append(ChatTextSegment(text=text))
            visible_parts.append(text)

        # Only well-formed tags whose id was resolved become citation
        # segments; every other candidate is simply removed from the text.
        valid_match = DOC_REF_RE.fullmatch(match.group(0))
        if valid_match:
            citation = citations_by_id.get(valid_match.group("id"))
            if citation is not None:
                segments.append(ChatCitationSegment(citation=citation))

        last_end = match.end()

    if last_end < len(raw_text):
        # finditer has already consumed every candidate tag, so the tail
        # cannot contain another one and needs no further substitution.
        text = raw_text[last_end:]
        segments.append(ChatTextSegment(text=text))
        visible_parts.append(text)

    return ParsedCitationResponse(
        content="".join(visible_parts).strip(),
        segments=segments,
        citations=citations,
    )