Aluode's picture
Upload folder using huggingface_hub
3bb804c verified
raw
history blame
5.47 kB
from __future__ import annotations
import math
import re
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, NamedTuple, TypeVar, Union
import numpy as np
from edfio.edf_signal import BdfSignal, EdfSignal
# Matches one timestamped annotation list (TAL) inside a decoded annotations
# data record: a signed onset, an optional duration introduced by \x15, the
# annotation texts introduced by \x14 (captured as ONE group, still joined by
# \x14), and the \x14\x00 terminator. `findall` therefore yields
# (onset, duration, texts) string triples, with "" for a missing duration.
#
# re.VERBOSE allows the inline comments below; the control characters are not
# whitespace, so they stay significant. re.DOTALL lets `.` match line breaks:
# `_EdfTAL.to_bytes` writes texts verbatim, so a text containing "\n" must be
# matched on the way back in instead of being silently skipped. The lazy
# `.*?` still stops at the first \x14\x00 terminator.
_ANNOTATIONS_PATTERN = re.compile(
    """
    ([+-]\\d+(?:\\.?\\d+)?)       # onset
    (?:\x15(\\d+(?:\\.?\\d+)?))?  # duration, optional
    (?:\x14(.*?))                 # annotation texts
    \x14\x00                      # terminator
    """,
    re.VERBOSE | re.DOTALL,
)
def _encode_annotation_onset(onset: float) -> str:
    """Format an annotation onset as a signed decimal string.

    The value is rendered with an explicit sign and 12 decimal places, after
    which trailing zeros (and a then-dangling decimal point) are removed,
    e.g. ``1.5`` -> ``"+1.5"`` and ``10`` -> ``"+10"``.

    Parameters
    ----------
    onset : float
        The onset value to encode.

    Returns
    -------
    str
        The signed, zero-trimmed representation.
    """
    trimmed = f"{onset:+.12f}".rstrip("0")
    # Whole numbers end up as e.g. "+10." once the zeros are gone.
    return trimmed[:-1] if trimmed.endswith(".") else trimmed
def _encode_annotation_duration(duration: float) -> str:
    """Format an annotation duration as an unsigned decimal string.

    The value is rendered with 12 decimal places, after which trailing zeros
    (and a then-dangling decimal point) are removed, e.g. ``2.5`` -> ``"2.5"``
    and ``30`` -> ``"30"``.

    Parameters
    ----------
    duration : float
        The duration to encode. Zero is accepted; negative values are not.

    Returns
    -------
    str
        The zero-trimmed representation.

    Raises
    ------
    ValueError
        If `duration` is below zero.
    """
    if duration < 0:
        raise ValueError(f"Annotation duration must be positive, is {duration}")
    trimmed = f"{duration:.12f}".rstrip("0")
    # Whole numbers end up as e.g. "30." once the zeros are gone.
    return trimmed[:-1] if trimmed.endswith(".") else trimmed
class EdfAnnotation(NamedTuple):
    """A single EDF+ annotation.

    Instances can be ordered with ``<``: they compare by onset, then
    duration (an annotation without duration sorts as if it were ``-1``),
    then text.

    Parameters
    ----------
    onset : float
        The annotation onset in seconds from recording start.
    duration : float | None
        The annotation duration in seconds (`None` if annotation has no duration).
    text : str
        The annotation text, can be empty.
    """

    onset: float
    duration: float | None
    text: str

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, EdfAnnotation):
            return NotImplemented  # pragma: no cover
        # Plain tuple comparison would fail on `None < float`, so a missing
        # duration is mapped to -1 before comparing.
        own_duration = -1 if self.duration is None else self.duration
        other_duration = -1 if other.duration is None else other.duration
        own_key = (self.onset, own_duration, self.text)
        other_key = (other.onset, other_duration, other.text)
        return own_key < other_key
# Type variable bound to the two signal classes; lets
# `_create_annotations_signal` declare that it returns an instance of the
# same class it receives as `signal_class`.
_Signal = TypeVar("_Signal", bound=Union[EdfSignal, BdfSignal])
def _create_annotations_signal(
    annotations: Iterable[EdfAnnotation],
    *,
    num_data_records: int,
    data_record_duration: float,
    signal_class: type[_Signal],
    with_timestamps: bool = True,
    subsecond_offset: float = 0,
) -> _Signal:
    """Pack annotations into an annotations signal of type `signal_class`.

    Annotations are distributed over `num_data_records` data records: each
    record receives the annotations whose onset lies before the record's end,
    and the last record additionally receives everything still left over.
    All encoded records are zero-padded to a common length that is a multiple
    of ``signal_class._bytes_per_sample``.

    Parameters
    ----------
    annotations : Iterable[EdfAnnotation]
        The annotations to store; they are sorted internally.
    num_data_records : int
        Number of data records to generate.
    data_record_duration : float
        Duration of one data record; record ``i`` starts at
        ``i * data_record_duration``.
    signal_class : type[_Signal]
        Signal class to instantiate for the result.
    with_timestamps : bool, default: True
        If true, every record starts with a TAL holding the record start
        (plus `subsecond_offset`), no duration and a single empty text.
    subsecond_offset : float, default: 0
        Value added to every stored onset, including the record timestamps.

    Returns
    -------
    _Signal
        The signal whose raw digital data are the concatenated, padded
        records as ``uint8`` values.

    Notes
    -----
    With ``num_data_records == 0`` no record is produced and the ``max()``
    call below raises `ValueError`.
    """
    sample_width = signal_class._bytes_per_sample
    final_record = num_data_records - 1
    # Sorted descending so the earliest pending annotation sits at the tail,
    # where list.pop() is O(1) (list.pop(0) would be O(n)).
    pending = sorted(annotations, reverse=True)
    encoded_records = []
    record_starts = np.arange(num_data_records) * data_record_duration
    for index, record_start in enumerate(record_starts):
        record_end = record_start + data_record_duration
        record_tals: list[_EdfTAL] = []
        if with_timestamps:
            # Timestamp TAL: record start, no duration, one empty text.
            record_tals.append(_EdfTAL(record_start + subsecond_offset, None, [""]))
        # The final record takes every annotation that is still pending.
        take_all = index == final_record
        while pending and (pending[-1].onset < record_end or take_all):
            current = pending.pop()
            record_tals.append(
                _EdfTAL(
                    current.onset + subsecond_offset,
                    current.duration,
                    [current.text],
                )
            )
        encoded_records.append(_EdfAnnotationsDataRecord(record_tals).to_bytes())

    # Round the longest record up to a whole number of samples.
    longest = max(len(encoded) for encoded in encoded_records)
    samples_per_record = math.ceil(longest / sample_width)
    record_size = samples_per_record * sample_width
    payload = b"".join(
        encoded.ljust(record_size, b"\x00") for encoded in encoded_records
    )

    signal = signal_class(
        np.arange(1.0),  # placeholder signal, as argument `data` is non-optional
        # A zero record duration falls back to a divisor of 1.
        sampling_frequency=samples_per_record / (data_record_duration or 1),
        physical_range=signal_class._default_digital_range,
    )
    signal._label = f"{signal_class._fmt} Annotations ".encode()
    signal._set_samples_per_data_record(samples_per_record)
    signal._digital = np.frombuffer(payload, dtype=np.uint8).copy()  # type: ignore[assignment]
    return signal  # type: ignore[return-value]
@dataclass
class _EdfTAL:
    """One timestamped annotation list: a shared onset/duration plus texts."""

    # Onset, encoded via `_encode_annotation_onset`.
    onset: float
    # Duration, encoded via `_encode_annotation_duration`; `None` omits it.
    duration: float | None
    # Annotation texts sharing the onset and duration above.
    texts: list[str]

    def to_bytes(self) -> bytes:
        """Serialize as ``onset[\\x15duration]\\x14text\\x14text...\\x14``."""
        onset_field = _encode_annotation_onset(self.onset)
        if self.duration is None:
            duration_field = ""
        else:
            duration_field = f"\x15{_encode_annotation_duration(self.duration)}"
        joined_texts = "\x14".join(self.texts)
        return f"{onset_field}{duration_field}\x14{joined_texts}\x14".encode()
@dataclass
class _EdfAnnotationsDataRecord:
    """The TALs stored in one data record of an annotations signal."""

    # TALs in the order they are serialized.
    tals: list[_EdfTAL]

    def to_bytes(self) -> bytes:
        """Serialize all TALs, each one followed by a ``\\x00`` byte.

        An empty record serializes to a single ``\\x00``.
        """
        encoded_tals = [tal.to_bytes() for tal in self.tals]
        return b"\x00".join(encoded_tals) + b"\x00"

    @classmethod
    def from_bytes(cls, raw: bytes) -> _EdfAnnotationsDataRecord:
        """Parse the raw bytes of one annotations data record.

        Raises
        ------
        ValueError
            If `raw` contains anything besides ``\\x00`` bytes but no TAL
            matching `_ANNOTATIONS_PATTERN`.
        """
        found: list[tuple[str, str, str]] = _ANNOTATIONS_PATTERN.findall(raw.decode())
        # A record consisting solely of padding is valid and simply empty.
        if not found and raw.replace(b"\x00", b""):
            raise ValueError(f"No valid annotations found in {raw!r}")
        return cls(
            [
                _EdfTAL(
                    float(onset),
                    # The pattern yields "" when no duration was present.
                    float(duration) if duration else None,
                    texts.split("\x14"),
                )
                for onset, duration, texts in found
            ]
        )

    @property
    def annotations(self) -> list[EdfAnnotation]:
        """All annotations of the record, one per text of every TAL."""
        flattened: list[EdfAnnotation] = []
        for tal in self.tals:
            flattened.extend(
                EdfAnnotation(tal.onset, tal.duration, text) for text in tal.texts
            )
        return flattened

    def drop_annotations_with_text(self, text: str) -> None:
        """Remove every text equal to `text`, then drop TALs left without texts."""
        remaining = []
        for tal in self.tals:
            # Slice assignment keeps mutating the existing list object in place.
            tal.texts[:] = [entry for entry in tal.texts if entry != text]
            if tal.texts:
                remaining.append(tal)
        self.tals = remaining