File size: 3,380 Bytes
0df80b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import annotations

import re
from dataclasses import dataclass

from src.utils.agent_utils import DOC_REF_RE
from src.schemas import ChatCitation, ChatCitationSegment, ChatMessageSegment, ChatTextSegment
from src.services.blog import citation_repository

# Loose matcher for anything that looks like a doc-ref tag: the literal "<doc-ref"
# followed by a word boundary, then everything up to and including the first ">"
# (``[^>]*`` already absorbs any "/", so "<doc-ref .../>" and "<doc-ref ...>" both
# match). parse_citation_segments strips every such span from the visible text,
# and only turns the ones that also fullmatch DOC_REF_RE into citations.
DOC_REF_CANDIDATE_RE = re.compile(r"<doc-ref\b[^>]*\/?>")


@dataclass(slots=True)
class ParsedCitationResponse:
    """Result of splitting a raw model response into text and citation segments."""

    # Visible text with doc-ref tags removed; parse_citation_segments strips
    # leading/trailing whitespace from it.
    content: str
    # Text and citation segments in the order they appear in the response.
    segments: list[ChatMessageSegment]
    # Citations returned by the citation repository for the referenced documents.
    citations: list[ChatCitation]


class CitationTagStreamFilter:
    """Incrementally strip ``<doc-ref ...>`` tags from streamed text.

    Text is fed in arbitrary chunks via :meth:`feed`; each call returns only the
    text that is already known to be visible. Anything that might still turn out
    to be part of a tag is held back in ``pending`` until more input arrives or
    :meth:`flush` is called.

    Tag boundaries mirror ``DOC_REF_CANDIDATE_RE``: the literal ``<doc-ref``
    followed by a non-word character, running up to and including the first
    ``>``. Both ``<doc-ref .../>`` and ``<doc-ref ...>`` are therefore hidden,
    which keeps the streamed text in line with the content built by
    ``parse_citation_segments``. One deliberate difference: a tag that is still
    unterminated when the stream ends is discarded by :meth:`flush` rather than
    shown.
    """

    _TAG_OPEN = "<doc-ref"
    _TAG_CLOSE = ">"

    def __init__(self) -> None:
        # Input not yet emitted: empty, a possible tag prefix such as "<doc-r",
        # or an opened tag still waiting for its closing ">".
        self.pending = ""
        # True while ``pending`` starts with a confirmed tag opening whose ">"
        # has not arrived yet.
        self.inside_tag = False

    @staticmethod
    def _is_word_char(char: str) -> bool:
        """Return True if *char* is a word character (as ``\\w`` in a str regex)."""
        return char.isalnum() or char == "_"

    def feed(self, delta: str) -> str:
        """Append *delta* to the buffer and return the newly visible text.

        Args:
            delta: The next chunk of streamed model output.

        Returns:
            Text that can be shown now. It is empty when the chunk ends inside a
            tag, or where a tag might still begin.
        """
        self.pending += delta
        visible_parts: list[str] = []
        tag_open = self._TAG_OPEN

        while self.pending:
            if self.inside_tag:
                close_index = self.pending.find(self._TAG_CLOSE)
                if close_index == -1:
                    break  # Tag still open; wait for its ">".
                self.pending = self.pending[close_index + len(self._TAG_CLOSE) :]
                self.inside_tag = False
                continue

            lt_index = self.pending.find("<")
            if lt_index == -1:
                visible_parts.append(self.pending)
                self.pending = ""
                break

            if lt_index > 0:
                visible_parts.append(self.pending[:lt_index])
                self.pending = self.pending[lt_index:]
                continue

            # From here on ``pending`` starts with "<".
            if tag_open.startswith(self.pending):
                # "<", "<d", ..., "<doc-ref": this could still become a tag. For
                # the full prefix, the boundary character has not arrived yet.
                break

            if self.pending.startswith(tag_open):
                boundary_char = self.pending[len(tag_open)]
                if not self._is_word_char(boundary_char):
                    # Matches "<doc-ref\b": the tag body is consumed once ">" arrives.
                    self.inside_tag = True
                    continue

            # A "<" that does not open a tag (e.g. "a < b" or "<doc-refs>") is
            # ordinary visible text.
            visible_parts.append("<")
            self.pending = self.pending[1:]

        return "".join(visible_parts)

    def flush(self) -> str:
        """Return any held-back visible text at end of stream and reset state.

        A partial prefix such as ``"<doc-r"`` never became a tag, so it is
        returned as plain text. An unterminated tag is dropped instead of being
        leaked to the reader.
        """
        leftover = "" if self.inside_tag else self.pending
        self.pending = ""
        self.inside_tag = False
        return leftover


def parse_citation_segments(raw_text: str, *, allowed_document_ids: set[str]) -> ParsedCitationResponse:
    """Split a raw model response into visible text and citation segments.

    Every span matching ``DOC_REF_CANDIDATE_RE`` is removed from the visible
    text. A span becomes a ``ChatCitationSegment`` only if it fully matches
    ``DOC_REF_RE`` and its ``id`` has a citation resolved below. Every other
    candidate span is dropped silently.

    Args:
        raw_text: Model output that may contain doc-ref tags.
        allowed_document_ids: Ids the response may cite. References to any
            other id are never sent to the repository and so never become
            citation segments.

    Returns:
        A ``ParsedCitationResponse`` with these fields:

        - ``content``: the visible text with tags removed and outer whitespace
          stripped.
        - ``segments``: text and citation segments interleaved in document
          order; text segments keep their original whitespace.
        - ``citations``: whatever ``citation_repository.resolve`` returned for
          the distinct allowed ids, which are requested in first-appearance
          order.
    """
    # Distinct allowed ids in first-appearance order. dict.fromkeys keeps
    # insertion order and dedups in O(n), unlike repeated list membership tests.
    ordered_doc_ids = list(
        dict.fromkeys(
            match.group("id")
            for match in DOC_REF_RE.finditer(raw_text)
            if match.group("id") in allowed_document_ids
        )
    )

    citations = citation_repository.resolve(ordered_doc_ids)
    citations_by_id = {citation.document_id: citation for citation in citations}

    segments: list[ChatMessageSegment] = []
    visible_parts: list[str] = []
    last_end = 0
    for match in DOC_REF_CANDIDATE_RE.finditer(raw_text):
        if match.start() > last_end:
            text = raw_text[last_end : match.start()]
            segments.append(ChatTextSegment(text=text))
            visible_parts.append(text)

        # The candidate is always removed from the visible text; it is kept as a
        # citation only if it is well-formed and its id maps to a resolved citation.
        valid_match = DOC_REF_RE.fullmatch(match.group(0))
        if valid_match is not None:
            citation = citations_by_id.get(valid_match.group("id"))
            if citation is not None:
                segments.append(ChatCitationSegment(citation=citation))
        last_end = match.end()

    if last_end < len(raw_text):
        # finditer has already consumed every candidate tag, so the remaining
        # tail is plain text and is kept verbatim.
        tail = raw_text[last_end:]
        segments.append(ChatTextSegment(text=tail))
        visible_parts.append(tail)

    return ParsedCitationResponse(content="".join(visible_parts).strip(), segments=segments, citations=citations)