# NOTE(review): as captured here, this module has been collapsed onto two
# physical lines and is not valid Python (e.g. "from __future__ import
# annotations import re"). Some text also appears to have been stripped as if
# it were markup -- every span from "<" up to the next ">":
#   * DOC_REF_CANDIDATE_RE is shown as r"]*\/?>", which only matches runs of
#     "]", an optional "/", and ">". A leading "<tag...[^>" portion is
#     presumably missing -- TODO confirm against version control.
#   * In CitationTagStreamFilter.feed, everything after
#     `self.pending.startswith("` up to " str:" is missing. That covers the
#     rest of feed and the header of the following method (the one that
#     returns `leftover`). The unmatched quote makes the remainder of the line
#     unparseable.
#   The tag name and that method's name cannot be recovered from this view,
#   so restore the file from version control rather than hand-reconstructing.
#
# Module contents (as far as visible):
#   * ParsedCitationResponse -- dataclass(slots=True) holding:
#       - the visible `content` string;
#       - the ordered `segments` (text and citation segments);
#       - the resolved `citations`.
#   * CitationTagStreamFilter -- incremental filter over streamed text deltas.
#       - `feed` buffers input in `pending` and emits the text that precedes
#         a "<" (or all of `pending` when there is no "<").
#       - While `inside_tag` is set, `feed` discards input up to and including
#         the next "/>", and waits for more input if none has arrived yet.
#       - The method after `feed` returns the buffered text, or "" if a tag
#         was still open, and resets both `pending` and `inside_tag`.
#     NOTE(review): the end-of-tag check only looks for "/>", whereas
#     DOC_REF_CANDIDATE_RE's visible suffix `\/?>` also accepts a plain ">".
#     Confirm that tags are always self-closing.
#   * parse_citation_segments(raw_text, *, allowed_document_ids)
#       - Collects the ids matched by DOC_REF_RE that are in
#         allowed_document_ids, in first-seen order without duplicates.
#       - Resolves those ids with citation_repository.resolve.
#       - Splits raw_text around every DOC_REF_CANDIDATE_RE match. Text
#         before, between and after matches becomes ChatTextSegment.
#       - A match that fullmatches DOC_REF_RE and whose id resolved becomes a
#         ChatCitationSegment. Any other match is dropped.
#       - `content` is the concatenated text, stripped.
#     NOTE(review): the trailing DOC_REF_CANDIDATE_RE.sub(...) looks
#     redundant, because finditer has already scanned the tail (unless the
#     real pattern relies on lookbehind). Also, `content` is stripped while
#     the first and last text segments keep their whitespace. Confirm both
#     are intended.
from __future__ import annotations import re from dataclasses import dataclass from src.utils.agent_utils import DOC_REF_RE from src.schemas import ChatCitation, ChatCitationSegment, ChatMessageSegment, ChatTextSegment from src.services.blog import citation_repository DOC_REF_CANDIDATE_RE = re.compile(r"]*\/?>") @dataclass(slots=True) class ParsedCitationResponse: content: str segments: list[ChatMessageSegment] citations: list[ChatCitation] class CitationTagStreamFilter: def __init__(self) -> None: self.pending = "" self.inside_tag = False def feed(self, delta: str) -> str: self.pending += delta visible_parts: list[str] = [] while self.pending: if self.inside_tag: end_index = self.pending.find("/>") if end_index == -1: break self.pending = self.pending[end_index + 2 :] self.inside_tag = False continue lt_index = self.pending.find("<") if lt_index == -1: visible_parts.append(self.pending) self.pending = "" break if lt_index > 0: visible_parts.append(self.pending[:lt_index]) self.pending = self.pending[lt_index:] continue if self.pending.startswith(" str: if self.inside_tag: leftover = "" else: leftover = self.pending self.pending = "" self.inside_tag = False return leftover def parse_citation_segments(raw_text: str, *, allowed_document_ids: set[str]) -> ParsedCitationResponse: ordered_doc_ids: list[str] = [] for match in DOC_REF_RE.finditer(raw_text): document_id = match.group("id") if document_id not in allowed_document_ids: continue if document_id not in ordered_doc_ids: ordered_doc_ids.append(document_id) citations = citation_repository.resolve(ordered_doc_ids) citations_by_id = {citation.document_id: citation for citation in citations} segments: list[ChatMessageSegment] = [] visible_parts: list[str] = [] last_end = 0 for match in DOC_REF_CANDIDATE_RE.finditer(raw_text): if match.start() > last_end: text = raw_text[last_end : match.start()] segments.append(ChatTextSegment(text=text)) visible_parts.append(text) valid_match = DOC_REF_RE.fullmatch(match.group(0)) if 
valid_match: citation = citations_by_id.get(valid_match.group("id")) if citation is not None: segments.append(ChatCitationSegment(citation=citation)) last_end = match.end() if last_end < len(raw_text): text = DOC_REF_CANDIDATE_RE.sub("", raw_text[last_end:]) segments.append(ChatTextSegment(text=text)) visible_parts.append(text) return ParsedCitationResponse(content="".join(visible_parts).strip(), segments=segments, citations=citations)