File size: 5,467 Bytes
3bb804c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from __future__ import annotations

import math
import re
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, NamedTuple, TypeVar, Union

import numpy as np

from edfio.edf_signal import BdfSignal, EdfSignal

# Pattern matching a single TAL inside a decoded annotations data record
# (each match is turned into one `_EdfTAL` by
# `_EdfAnnotationsDataRecord.from_bytes`). It captures three groups:
#   1. the onset, which must carry an explicit sign,
#   2. the duration, present only if introduced by the byte 0x15,
#   3. everything between the first 0x14 separator and the 0x14 0x00
#      terminator (one or more texts, themselves separated by 0x14).
# The string is not a raw literal, so `\x14`, `\x15` and `\x00` are turned into
# the actual control characters by Python, while `\\d` / `\\.` reach the regex
# engine as `\d` / `\.`. `re.VERBOSE` makes the engine ignore the layout
# whitespace and the `#` comments inside the pattern. `re.DOTALL` is not set,
# so the texts group does not match across a newline character.
_ANNOTATIONS_PATTERN = re.compile(
    """
    ([+-]\\d+(?:\\.?\\d+)?)       # onset
    (?:\x15(\\d+(?:\\.?\\d+)?))?  # duration, optional
    (?:\x14(.*?))                 # annotation texts
    \x14\x00                      # terminator
    """,
    re.VERBOSE,
)


def _encode_annotation_onset(onset: float) -> str:
    """Render an annotation onset as a signed decimal string.

    The value is formatted with an explicit sign and 12 decimal places; the
    trailing zeros are then removed, together with the decimal point if
    nothing is left behind it (e.g. ``1.5`` -> ``"+1.5"``, ``10`` -> ``"+10"``).
    """
    trimmed = f"{onset:+.12f}".rstrip("0")
    # A whole number leaves a dangling "." after the zeros are stripped.
    return trimmed[:-1] if trimmed.endswith(".") else trimmed


def _encode_annotation_duration(duration: float) -> str:
    """Render an annotation duration as an unsigned decimal string.

    The value is formatted with 12 decimal places; the trailing zeros are then
    removed, together with the decimal point if nothing is left behind it
    (e.g. ``2.5`` -> ``"2.5"``, ``30`` -> ``"30"``).

    Raises
    ------
    ValueError
        If `duration` is less than zero.
    """
    if duration < 0:
        raise ValueError(f"Annotation duration must be positive, is {duration}")
    trimmed = f"{duration:.12f}".rstrip("0")
    # A whole number leaves a dangling "." after the zeros are stripped.
    return trimmed[:-1] if trimmed.endswith(".") else trimmed


class EdfAnnotation(NamedTuple):
    """One annotation of an EDF+ recording.

    Parameters
    ----------
    onset : float
        Start of the annotation, in seconds relative to the recording start.
    duration : float | None
        Length of the annotation in seconds, or `None` if it has no duration.
    text : str
        The annotation text; may be an empty string.
    """

    onset: float
    duration: float | None
    text: str

    def __lt__(self, other: Any) -> bool:
        """Order by onset, then duration, then text.

        A missing duration is compared as ``-1`` so that `None` never has to be
        compared against a number.
        """
        if not isinstance(other, EdfAnnotation):
            return NotImplemented  # pragma: no cover
        own_duration = -1 if self.duration is None else self.duration
        other_duration = -1 if other.duration is None else other.duration
        own_key = (self.onset, own_duration, self.text)
        other_key = (other.onset, other_duration, other.text)
        return own_key < other_key


# Type variable bound to the two signal classes imported above. It lets
# `_create_annotations_signal` be annotated as returning an instance of the
# same class that is passed in as `signal_class`.
_Signal = TypeVar("_Signal", bound=Union[EdfSignal, BdfSignal])


def _create_annotations_signal(
    annotations: Iterable[EdfAnnotation],
    *,
    num_data_records: int,
    data_record_duration: float,
    signal_class: type[_Signal],
    with_timestamps: bool = True,
    subsecond_offset: float = 0,
) -> _Signal:
    """Pack annotations into an annotations signal of type `signal_class`.

    Annotations are sorted and distributed over the data records: a record
    receives every not yet assigned annotation whose onset lies before the
    record's end, and the last record additionally receives all that remain.
    Every encoded record is padded with NUL bytes to a common length that is a
    multiple of ``signal_class._bytes_per_sample``.

    Parameters
    ----------
    annotations : Iterable[EdfAnnotation]
        The annotations to store.
    num_data_records : int
        Number of data records to spread the annotations over.
    data_record_duration : float
        Duration of one data record; record ``i`` starts at
        ``i * data_record_duration``.
    signal_class : type[_Signal]
        Class of the signal to create.
    with_timestamps : bool, default: True
        If true, every record starts with a TAL holding the record start (plus
        `subsecond_offset`), no duration and a single empty text.
    subsecond_offset : float, default: 0
        Value added to every written onset, including the record timestamps.
        It is not taken into account when assigning annotations to records.

    Returns
    -------
    _Signal
        A `signal_class` instance whose label, samples per data record and
        digital data have been overwritten with the annotation bytes.

    Notes
    -----
    With ``num_data_records == 0`` no record is produced and the `max()` call
    below raises `ValueError`.
    """
    sample_width = signal_class._bytes_per_sample
    record_starts = np.arange(num_data_records) * data_record_duration
    # Sorted in descending order so that the earliest remaining annotation is
    # at the end of the list, where `list.pop()` removes it in O(1).
    pending = sorted(annotations, reverse=True)
    last_index = num_data_records - 1

    encoded_records = []
    for index, record_start in enumerate(record_starts):
        record_end = record_start + data_record_duration
        record_tals: list[_EdfTAL] = []
        if with_timestamps:
            record_tals.append(_EdfTAL(record_start + subsecond_offset, None, [""]))
        while pending:
            # Stop at the first annotation belonging to a later record, unless
            # this is the last record, which takes everything that is left.
            if not (pending[-1].onset < record_end or index == last_index):
                break
            annotation = pending.pop()
            record_tals.append(
                _EdfTAL(
                    annotation.onset + subsecond_offset,
                    annotation.duration,
                    [annotation.text],
                )
            )
        encoded_records.append(_EdfAnnotationsDataRecord(record_tals).to_bytes())

    # Round the longest record up to a whole number of samples.
    longest = max(len(encoded) for encoded in encoded_records)
    samples_per_record = math.ceil(longest / sample_width)
    record_size = samples_per_record * sample_width
    raw = b"".join(encoded.ljust(record_size, b"\x00") for encoded in encoded_records)

    signal = signal_class(
        np.arange(1.0),  # placeholder signal, as argument `data` is non-optional
        # A zero record duration would divide by zero, so 1 is used instead.
        sampling_frequency=samples_per_record / (data_record_duration or 1),
        physical_range=signal_class._default_digital_range,
    )
    signal._label = f"{signal_class._fmt} Annotations ".encode()
    signal._set_samples_per_data_record(samples_per_record)
    signal._digital = np.frombuffer(raw, dtype=np.uint8).copy()  # type: ignore[assignment]
    return signal  # type: ignore[return-value]


@dataclass
class _EdfTAL:
    """An onset, an optional duration and the list of texts sharing them."""

    onset: float
    duration: float | None
    texts: list[str]

    def to_bytes(self) -> bytes:
        """Encode as ``onset[\\x15duration]\\x14text[\\x14text...]\\x14``."""
        onset_field = _encode_annotation_onset(self.onset)
        if self.duration is None:
            timing = onset_field
        else:
            # The duration, when present, follows the onset after a 0x15 byte.
            timing = f"{onset_field}\x15{_encode_annotation_duration(self.duration)}"
        body = "\x14".join(self.texts)
        return f"{timing}\x14{body}\x14".encode()


@dataclass
class _EdfAnnotationsDataRecord:
    """The TALs stored in one data record of an annotations signal."""

    tals: list[_EdfTAL]

    def to_bytes(self) -> bytes:
        """Encode all TALs, each followed by a NUL byte.

        A record without TALs is encoded as a single NUL byte.
        """
        encoded_tals = [tal.to_bytes() for tal in self.tals]
        return b"\x00".join(encoded_tals) + b"\x00"

    @classmethod
    def from_bytes(cls, raw: bytes) -> _EdfAnnotationsDataRecord:
        """Parse the raw bytes of one annotations data record.

        Raises
        ------
        ValueError
            If `raw` contains bytes other than NUL but no TAL can be found.
        """
        found: list[tuple[str, str, str]] = _ANNOTATIONS_PATTERN.findall(raw.decode())
        # A record consisting only of NUL padding is valid and simply empty.
        if not found and raw.strip(b"\x00"):
            raise ValueError(f"No valid annotations found in {raw!r}")
        return cls(
            [
                _EdfTAL(
                    float(onset),
                    float(duration) if duration else None,
                    texts.split("\x14"),
                )
                for onset, duration, texts in found
            ]
        )

    @property
    def annotations(self) -> list[EdfAnnotation]:
        """One `EdfAnnotation` per text of every TAL, in stored order."""
        flattened: list[EdfAnnotation] = []
        for tal in self.tals:
            for text in tal.texts:
                flattened.append(EdfAnnotation(tal.onset, tal.duration, text))
        return flattened

    def drop_annotations_with_text(self, text: str) -> None:
        """Remove every occurrence of `text` and any TAL left without texts."""
        remaining = []
        for tal in self.tals:
            # Slice assignment keeps mutating the existing list object.
            tal.texts[:] = [entry for entry in tal.texts if entry != text]
            if tal.texts:
                remaining.append(tal)
        self.tals = remaining