Spaces:
Running
Running
File size: 5,467 Bytes
3bb804c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
from __future__ import annotations
import math
import re
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, NamedTuple, TypeVar, Union
import numpy as np
from edfio.edf_signal import BdfSignal, EdfSignal
# Regular expression matching one time-stamped annotation list (TAL) inside a
# decoded annotations data record. It yields three capture groups per match:
#   1. the onset: a mandatory sign followed by digits, optionally with a
#      fractional part,
#   2. the duration (optional): an unsigned number introduced by "\x15",
#   3. the annotation texts: everything (non-greedy) between the "\x14" that
#      follows the timing information and the "\x14\x00" terminator; several
#      texts may be present, separated by "\x14".
# The pattern string is not a raw string: the doubled backslashes reach the
# regex engine as `\d` / `\.`, while "\x14", "\x15" and "\x00" are embedded as
# literal control characters. `re.VERBOSE` makes the engine ignore the line
# breaks, the spaces and the trailing `# ...` comments inside the pattern.
# NOTE: `re.DOTALL` is not set, so `.` in the texts group does not match "\n".
_ANNOTATIONS_PATTERN = re.compile(
"""
([+-]\\d+(?:\\.?\\d+)?) # onset
(?:\x15(\\d+(?:\\.?\\d+)?))? # duration, optional
(?:\x14(.*?)) # annotation texts
\x14\x00 # terminator
""",
re.VERBOSE,
)
def _encode_annotation_onset(onset: float) -> str:
    """Format an annotation onset as a signed decimal string.

    The value is rendered with an explicit sign and up to 12 decimal places;
    trailing zeros and a then-dangling decimal point are removed, so
    ``1.5`` becomes ``"+1.5"`` and ``10`` becomes ``"+10"``.
    """
    fixed_point = f"{onset:+.12f}"
    # Fixed-point output always contains exactly one ".", so stripping zeros
    # first and the dot second can never eat into the integer part.
    return fixed_point.rstrip("0").rstrip(".")
def _encode_annotation_duration(duration: float) -> str:
    """Format an annotation duration as an unsigned decimal string.

    The value is rendered with up to 12 decimal places; trailing zeros and a
    then-dangling decimal point are removed.

    Raises
    ------
    ValueError
        If `duration` is negative (zero is accepted).
    """
    if duration < 0:
        raise ValueError(f"Annotation duration must be positive, is {duration}")
    fixed_point = f"{duration:.12f}"
    # Fixed-point output always contains exactly one ".", so stripping zeros
    # first and the dot second can never eat into the integer part.
    return fixed_point.rstrip("0").rstrip(".")
class EdfAnnotation(NamedTuple):
    """A single EDF+ annotation.

    Parameters
    ----------
    onset : float
        Start of the annotation, in seconds relative to the recording start.
    duration : float | None
        Length of the annotation in seconds, or `None` when the annotation
        carries no duration.
    text : str
        The annotation text; may be an empty string.
    """

    onset: float
    duration: float | None
    text: str

    def __lt__(self, other: Any) -> bool:
        # Only `<` is customized (it is what sorting relies on); the remaining
        # rich comparisons are the ones inherited from tuple.
        if not isinstance(other, EdfAnnotation):
            return NotImplemented  # pragma: no cover
        # A missing duration is replaced by -1 for ordering purposes so that
        # `None` can be compared with numeric durations and sorts before them.
        own_duration = -1 if self.duration is None else self.duration
        other_duration = -1 if other.duration is None else other.duration
        own_key = (self.onset, own_duration, self.text)
        other_key = (other.onset, other_duration, other.text)
        return own_key < other_key
# Type variable bound to the two signal classes (EDF and BDF). It lets
# `_create_annotations_signal` declare that it returns an instance of the very
# class it receives through its `signal_class` argument.
_Signal = TypeVar("_Signal", bound=Union[EdfSignal, BdfSignal])
def _create_annotations_signal(
    annotations: Iterable[EdfAnnotation],
    *,
    num_data_records: int,
    data_record_duration: float,
    signal_class: type[_Signal],
    with_timestamps: bool = True,
    subsecond_offset: float = 0,
) -> _Signal:
    """Pack annotations into an "Annotations" signal of the given class.

    Parameters
    ----------
    annotations : Iterable[EdfAnnotation]
        The annotations to store; they are sorted before being distributed.
    num_data_records : int
        Number of data records the annotations are spread over.
    data_record_duration : float
        Duration of one data record in seconds. A value of 0 is allowed; the
        sampling frequency is then computed with a divisor of 1.
    signal_class : type[_Signal]
        Signal class to instantiate; its `_bytes_per_sample`,
        `_default_digital_range` and `_fmt` attributes are used.
    with_timestamps : bool, default: True
        If True, every data record starts with a timekeeping entry (empty
        text, no duration) holding the record start time.
    subsecond_offset : float, default: 0
        Value added to every record timestamp and every annotation onset.

    Returns
    -------
    _Signal
        A signal whose `_digital` array holds the encoded records as bytes.

    Notes
    -----
    With ``num_data_records == 0`` no record is produced and the ``max()``
    call below raises `ValueError`.
    """
    sample_width = signal_class._bytes_per_sample
    record_starts = np.arange(num_data_records) * data_record_duration
    last_index = num_data_records - 1
    # Sorted in descending order so the earliest pending annotation sits at
    # the end of the list, where list.pop() removes it in O(1) (list.pop(0)
    # would be O(n)).
    pending = sorted(annotations, reverse=True)

    encoded_records = []
    for index, record_start in enumerate(record_starts):
        record_end = record_start + data_record_duration
        tals: list[_EdfTAL] = (
            [_EdfTAL(record_start + subsecond_offset, None, [""])]
            if with_timestamps
            else []
        )
        # A record takes every pending annotation starting before its end;
        # the last record additionally takes whatever is left over.
        while pending and (pending[-1].onset < record_end or index == last_index):
            annotation = pending.pop()
            tals.append(
                _EdfTAL(
                    annotation.onset + subsecond_offset,
                    annotation.duration,
                    [annotation.text],
                )
            )
        encoded_records.append(_EdfAnnotationsDataRecord(tals).to_bytes())

    # All records must share one length: take the longest one, rounded up to
    # a whole number of samples, and pad the others with NUL bytes.
    longest = max(len(record) for record in encoded_records)
    record_length = math.ceil(longest / sample_width) * sample_width
    samples_per_record = record_length // sample_width
    raw = b"".join(record.ljust(record_length, b"\x00") for record in encoded_records)

    divisor = data_record_duration if data_record_duration else 1
    signal = signal_class(
        np.arange(1.0),  # placeholder signal, as argument `data` is non-optional
        sampling_frequency=samples_per_record / divisor,
        physical_range=signal_class._default_digital_range,
    )
    signal._label = f"{signal_class._fmt} Annotations ".encode()
    signal._set_samples_per_data_record(samples_per_record)
    signal._digital = np.frombuffer(raw, dtype=np.uint8).copy()  # type: ignore[assignment]
    return signal  # type: ignore[return-value]
@dataclass
class _EdfTAL:
    """One time-stamped annotation list: a timing plus one or more texts."""

    # Onset in seconds; encoded with an explicit sign.
    onset: float
    # Duration in seconds, or None when no duration is stored.
    duration: float | None
    # Annotation texts sharing the onset/duration above.
    texts: list[str]

    def to_bytes(self) -> bytes:
        """Serialize as ``onset[\\x15duration]\\x14text[\\x14text...]\\x14``."""
        encoded = _encode_annotation_onset(self.onset)
        if self.duration is not None:
            # The duration, when present, is attached to the onset with "\x15".
            encoded += "\x15" + _encode_annotation_duration(self.duration)
        # Each text is preceded by "\x14" and the list is closed by one more.
        encoded += "\x14" + "\x14".join(self.texts) + "\x14"
        return encoded.encode()
@dataclass
class _EdfAnnotationsDataRecord:
    """The annotation content of one data record, as a list of TALs."""

    tals: list[_EdfTAL]

    def to_bytes(self) -> bytes:
        """Serialize all TALs, each one followed by a NUL byte."""
        encoded_tals = [tal.to_bytes() for tal in self.tals]
        return b"\x00".join(encoded_tals) + b"\x00"

    @classmethod
    def from_bytes(cls, raw: bytes) -> _EdfAnnotationsDataRecord:
        """Parse the raw bytes of one annotations data record.

        Raises
        ------
        ValueError
            If `raw` contains something other than NUL bytes but no TAL
            matching the annotations pattern could be found.
        """
        found: list[tuple[str, str, str]] = _ANNOTATIONS_PATTERN.findall(raw.decode())
        # A record made up solely of NUL padding is valid and simply empty.
        if not found and raw.replace(b"\x00", b""):
            raise ValueError(f"No valid annotations found in {raw!r}")
        parsed: list[_EdfTAL] = [
            _EdfTAL(
                float(onset),
                # The optional duration group is "" when absent.
                float(duration) if duration else None,
                texts.split("\x14"),
            )
            for onset, duration, texts in found
        ]
        return cls(parsed)

    @property
    def annotations(self) -> list[EdfAnnotation]:
        """All annotations of the record, one per text of every TAL."""
        flattened: list[EdfAnnotation] = []
        for tal in self.tals:
            for text in tal.texts:
                flattened.append(EdfAnnotation(tal.onset, tal.duration, text))
        return flattened

    def drop_annotations_with_text(self, text: str) -> None:
        """Remove every occurrence of `text`, then discard TALs left empty."""
        for tal in self.tals:
            # Slice assignment keeps the very same list object, so the texts
            # are still modified in place.
            tal.texts[:] = [entry for entry in tal.texts if entry != text]
        self.tals = [tal for tal in self.tals if tal.texts]
|