File size: 1,805 Bytes
75db650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""Citation helpers (numbered citations like [1], [2], ...)."""

from __future__ import annotations

import re
from typing import Any, Dict, List, Tuple


class CitationManager:
    """Assign 1-based citation numbers to documents and check ``[n]`` markers.

    Documents are kept in insertion order; a document's citation number is its
    1-based position in ``self.documents``. Documents that share a non-empty
    ``doc_id`` are deduplicated and reuse the number of the first occurrence.
    """

    # Matches numbered citation markers such as "[1]" or "[12]".
    _CITATION_RE = re.compile(r"\[(\d+)\]")

    def __init__(self, *, max_content_length: int = 900):
        # Stored for callers; not read by any method of this class.
        self.max_content_length = int(max_content_length)
        # Registered documents; citation number n refers to documents[n - 1].
        self.documents: List[Dict[str, Any]] = []
        # Maps a non-empty doc_id to its 1-based citation number.
        self.doc_id_to_index: Dict[str, int] = {}

    def clear(self) -> None:
        """Forget all registered documents and their citation numbers."""
        self.documents = []
        self.doc_id_to_index = {}

    def add_document(self, document: Dict[str, Any]) -> int:
        """Register *document* and return its 1-based citation number.

        If ``document["doc_id"]`` is non-empty and was registered before, the
        existing number is returned and *document* is not stored again.
        Documents without a usable ``doc_id`` are always appended as new
        entries: they have no identity to deduplicate on, so they must not
        collapse onto one another.
        """
        doc_id = document.get("doc_id") or ""
        if doc_id:
            existing = self.doc_id_to_index.get(doc_id)
            if existing is not None:
                return existing
        self.documents.append(document)
        idx = len(self.documents)
        if doc_id:
            self.doc_id_to_index[doc_id] = idx
        return idx

    def add_documents(self, documents: List[Dict[str, Any]]) -> List[int]:
        """Register each document in order; return their citation numbers."""
        return [self.add_document(d) for d in documents]

    @staticmethod
    def parse_citations_in_text(text: str) -> List[int]:
        """Return every ``[n]`` citation number in *text*, in order of appearance.

        Duplicates are kept. ``None`` or empty text yields ``[]``.
        """
        out: List[int] = []
        for m in CitationManager._CITATION_RE.findall(text or ""):
            try:
                out.append(int(m))
            except ValueError:
                # Recent CPython versions refuse to convert digit strings longer
                # than the configured int/str limit; skip such markers.
                continue
        return out

    def validate_citations(self, text: str) -> Tuple[bool, List[int]]:
        """Check that every ``[n]`` in *text* refers to a registered document.

        Returns ``(ok, invalid)``: *invalid* lists each out-of-range number in
        order of appearance (duplicates kept), and *ok* is true iff it is empty.
        """
        total = len(self.documents)
        invalid = [
            i for i in self.parse_citations_in_text(text or "")
            if not 1 <= i <= total
        ]
        return (len(invalid) == 0), invalid

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize the registered documents.

        Returns ``{"total": <count>, "source_type_counts": {type: count}}``;
        documents whose ``source_type`` is missing or falsy count as
        ``"unknown"``.
        """
        counts: Dict[str, int] = {}
        for d in self.documents:
            st = d.get("source_type") or "unknown"
            counts[st] = counts.get(st, 0) + 1
        return {"total": len(self.documents), "source_type_counts": counts}