Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import re | |
| from dataclasses import dataclass | |
| from src.utils.agent_utils import DOC_REF_RE | |
| from src.schemas import ChatCitation, ChatCitationSegment, ChatMessageSegment, ChatTextSegment | |
| from src.services.blog import citation_repository | |
# Loose matcher for anything shaped like a doc-ref tag: the literal "<doc-ref",
# a word boundary, any run of non-">" characters, an optional "/", then ">".
# It therefore accepts both self-closing ("/>") and bare (">") endings.
DOC_REF_CANDIDATE_RE = re.compile(r"<doc-ref\b[^>]*\/?>")
@dataclass
class ParsedCitationResponse:
    """Container for the outcome of parsing a response for citation tags.

    Declared as a dataclass so it can be built with keyword arguments
    (``content=``, ``segments=``, ``citations=``), which is how
    ``parse_citation_segments`` constructs it.
    """

    # Joined visible text of the response (citation tags removed).
    content: str
    # Text and citation segments in the order they were produced.
    segments: list[ChatMessageSegment]
    # Citations that were resolved for the response.
    citations: list[ChatCitation]
class CitationTagStreamFilter:
    """Incrementally remove ``<doc-ref ...>`` tags from streamed text.

    Text is fed in arbitrary chunks via :meth:`feed`. Characters are held
    back only while they might still turn out to be the start of, or the
    inside of, a ``<doc-ref`` tag; everything else is returned immediately.
    Call :meth:`flush` once the stream ends to retrieve any held-back text.
    """

    def __init__(self) -> None:
        # Text received but not yet released (may hold a partial tag).
        self.pending = ""
        # True while ``pending`` begins inside a tag whose end was not seen yet.
        self.inside_tag = False

    def feed(self, delta: str) -> str:
        """Consume the next chunk and return the text that is safe to show.

        Args:
            delta: The newly streamed piece of text.

        Returns:
            The visible text released by this call, with complete
            ``<doc-ref ...>`` tags removed. May be empty when the buffer
            ends in a partial tag.
        """
        self.pending += delta
        visible_parts: list[str] = []
        while self.pending:
            if self.inside_tag:
                # A tag ends at the first ">", matching DOC_REF_CANDIDATE_RE,
                # so both "/>" and a bare ">" terminate it. Waiting for "/>"
                # only would swallow all text after a non-self-closed tag.
                end_index = self.pending.find(">")
                if end_index == -1:
                    # Tag still incomplete; wait for more input.
                    break
                self.pending = self.pending[end_index + 1 :]
                self.inside_tag = False
                continue

            lt_index = self.pending.find("<")
            if lt_index == -1:
                # No possible tag start: everything is visible.
                visible_parts.append(self.pending)
                self.pending = ""
                break
            if lt_index > 0:
                # Release the text before the "<" and re-examine from there.
                visible_parts.append(self.pending[:lt_index])
                self.pending = self.pending[lt_index:]
                continue

            # ``pending`` now starts with "<".
            if self.pending.startswith("<doc-ref"):
                self.inside_tag = True
                continue
            if "<doc-ref".startswith(self.pending):
                # Could still grow into "<doc-ref"; hold it back for now.
                break
            # Some other "<": release it and keep scanning after it.
            visible_parts.append(self.pending[0])
            self.pending = self.pending[1:]
        return "".join(visible_parts)

    def flush(self) -> str:
        """Reset the filter and return whatever text was still held back.

        Returns:
            The buffered text, or an empty string when the buffer is the
            remainder of an unterminated ``<doc-ref`` tag (which is dropped).
        """
        leftover = "" if self.inside_tag else self.pending
        self.pending = ""
        self.inside_tag = False
        return leftover
def parse_citation_segments(raw_text: str, *, allowed_document_ids: set[str]) -> ParsedCitationResponse:
    """Split a raw response into text and citation segments.

    Args:
        raw_text: Response text that may contain ``<doc-ref ...>`` tags.
        allowed_document_ids: Document ids that may be cited; references to
            any other id are ignored.

    Returns:
        A ``ParsedCitationResponse`` holding the stripped visible text, the
        ordered segments, and the citations returned by
        ``citation_repository.resolve``.
    """
    # Pass 1: allowed ids, de-duplicated in first-appearance order.
    referenced_ids = (ref.group("id") for ref in DOC_REF_RE.finditer(raw_text))
    ordered_doc_ids: list[str] = list(
        dict.fromkeys(doc_id for doc_id in referenced_ids if doc_id in allowed_document_ids)
    )

    citations = citation_repository.resolve(ordered_doc_ids)
    citations_by_id = {entry.document_id: entry for entry in citations}

    segments: list[ChatMessageSegment] = []
    visible_parts: list[str] = []

    def emit_text(chunk: str) -> None:
        # Record a run of plain text both as a segment and as visible content.
        segments.append(ChatTextSegment(text=chunk))
        visible_parts.append(chunk)

    # Pass 2: walk every tag-shaped candidate, emitting the text between
    # candidates and a citation segment for each resolvable reference.
    cursor = 0
    for candidate in DOC_REF_CANDIDATE_RE.finditer(raw_text):
        start, end = candidate.span()
        if start > cursor:
            emit_text(raw_text[cursor:start])
        cursor = end

        well_formed = DOC_REF_RE.fullmatch(candidate.group(0))
        if not well_formed:
            # Tag-shaped but not a valid reference: dropped from the output.
            continue
        citation = citations_by_id.get(well_formed.group("id"))
        if citation is not None:
            segments.append(ChatCitationSegment(citation=citation))

    # Trailing text after the last candidate (candidate tags stripped from it).
    if cursor < len(raw_text):
        emit_text(DOC_REF_CANDIDATE_RE.sub("", raw_text[cursor:]))

    return ParsedCitationResponse(
        content="".join(visible_parts).strip(),
        segments=segments,
        citations=citations,
    )