File size: 3,599 Bytes
482c2e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
extractor.py — Extract EvaluationSpans from an annotated or gold-standard XML element.

A span is positioned by its character offsets in the element's *plain text*
(i.e. all tags stripped, just the text content concatenated in document order).
This makes gold and predicted spans directly comparable as long as both are
derived from the same source text.
"""

from __future__ import annotations

from dataclasses import dataclass, field

from lxml import etree


def _strip_ns(tag: str) -> str:
    """Remove Clark-notation namespace from '{http://...}tag' → 'tag'."""
    if tag.startswith("{"):
        return tag.split("}", 1)[1]
    return tag


@dataclass
class EvaluationSpan:
    """One XML element located inside its container's tag-stripped text.

    Attributes
    ----------
    element
        Tag name with the namespace removed.
    start
        Inclusive character offset into the stripped plain text.
    end
        Exclusive character offset into the stripped plain text.
    text
        The text content covered by the span.
    attrs
        Attribute name/value pairs; a fresh empty dict per instance by default.
    """

    element: str
    start: int
    end: int
    text: str
    attrs: dict[str, str] = field(default_factory=dict)

    @property
    def normalized_text(self) -> str:
        """Span text with each whitespace run reduced to a single space.

        Leading and trailing whitespace is dropped as well, because the text
        is split on whitespace and re-joined.
        """
        tokens = self.text.split()
        return " ".join(tokens)


def _extract_recursive(
    element: etree._Element,
    spans: list[EvaluationSpan],
    offset: int,
) -> int:
    """
    Walk *element*'s subtree depth-first, appending an EvaluationSpan for
    every descendant element to *spans*.

    Comments and processing instructions are not elements: they produce no
    span and their own content is not counted, so that offsets stay aligned
    with ``"".join(element.itertext())``.  Their tail text is still counted,
    because it is ordinary text of the parent.

    Parameters
    ----------
    element
        Node whose children are visited.  No span is emitted for *element*
        itself; the caller that iterates over it as a child does that.
    spans
        Output list, mutated in place.  A child's span is appended after the
        spans of its own descendants (post-order).
    offset
        Character offset in the plain text at which *element*'s content begins.

    Returns the offset *after* all of element's content (excluding its tail,
    since tail text belongs to the *parent*, not to this element).
    """
    # Text that belongs to this element, before its first child
    current = offset + len(element.text or "")

    for child in element:
        if not isinstance(child.tag, str):
            # lxml gives comments and processing instructions a non-string
            # tag.  Passing it to _strip_ns would fail, and their .text is
            # not part of the plain text, so only the tail advances the offset.
            current += len(child.tail or "")
            continue

        child_start = current
        # Recurse: sets current to the end of all content inside child
        current = _extract_recursive(child, spans, current)

        spans.append(
            EvaluationSpan(
                element=_strip_ns(child.tag),
                start=child_start,
                end=current,
                text="".join(child.itertext()),
                attrs={_strip_ns(k): v for k, v in child.attrib.items()},
            )
        )

        # Tail text belongs to the parent, not the child
        current += len(child.tail or "")

    return current


def extract_spans(element: etree._Element) -> tuple[str, list[EvaluationSpan]]:
    """
    Build the plain text of *element* and the spans of its descendants.

    The plain text is ``"".join(element.itertext())``.  One EvaluationSpan is
    produced per descendant element, at any depth; its ``start``/``end`` are
    character offsets into that plain text (``start`` inclusive, ``end``
    exclusive).  No span is produced for *element* itself.  A nested element
    and its container yield overlapping spans, which is acceptable for
    matching.

    Returns
    -------
    (plain_text, spans)
        ``plain_text`` is the concatenated text content.
        ``spans`` is a flat list in depth-first order, with each element's
        descendants listed before the element itself.
    """
    plain_text = "".join(element.itertext())

    collected: list[EvaluationSpan] = []
    # Offsets are measured from the start of the element's own content.
    _extract_recursive(element, collected, offset=0)

    return plain_text, collected


def spans_from_xml_string(xml_str: str) -> tuple[str, list[EvaluationSpan]]:
    """
    Parse *xml_str* (one root element) and return its plain text and spans.

    Thin convenience wrapper: the string is encoded to bytes, parsed with
    ``etree.fromstring``, and the resulting root is handed to
    :func:`extract_spans`.
    """
    # str.encode() defaults to UTF-8.
    # NOTE(review): if the input carries an XML declaration naming a different
    # encoding, these bytes would disagree with it - confirm inputs.
    raw_bytes = xml_str.encode()
    return extract_spans(etree.fromstring(raw_bytes))