from __future__ import annotations import math import re from collections.abc import Iterable from dataclasses import dataclass from typing import Any, NamedTuple, TypeVar, Union import numpy as np from edfio.edf_signal import BdfSignal, EdfSignal _ANNOTATIONS_PATTERN = re.compile( """ ([+-]\\d+(?:\\.?\\d+)?) # onset (?:\x15(\\d+(?:\\.?\\d+)?))? # duration, optional (?:\x14(.*?)) # annotation texts \x14\x00 # terminator """, re.VERBOSE, ) def _encode_annotation_onset(onset: float) -> str: string = f"{onset:+.12f}".rstrip("0") if string[-1] == ".": return string[:-1] return string def _encode_annotation_duration(duration: float) -> str: if duration < 0: raise ValueError(f"Annotation duration must be positive, is {duration}") string = f"{duration:.12f}".rstrip("0") if string[-1] == ".": return string[:-1] return string class EdfAnnotation(NamedTuple): """A single EDF+ annotation. Parameters ---------- onset : float The annotation onset in seconds from recording start. duration : float | None The annotation duration in seconds (`None` if annotation has no duration). text : str The annotation text, can be empty. """ onset: float duration: float | None text: str def __lt__(self, other: Any) -> bool: if not isinstance(other, EdfAnnotation): return NotImplemented # pragma: no cover return ( self.onset, -1 if self.duration is None else self.duration, self.text, ) < ( other.onset, -1 if other.duration is None else other.duration, other.text, ) _Signal = TypeVar("_Signal", bound=Union[EdfSignal, BdfSignal]) def _create_annotations_signal( annotations: Iterable[EdfAnnotation], *, num_data_records: int, data_record_duration: float, signal_class: type[_Signal], with_timestamps: bool = True, subsecond_offset: float = 0, ) -> _Signal: bytes_per_sample = signal_class._bytes_per_sample data_record_starts = np.arange(num_data_records) * data_record_duration # list.pop() is O(1) and list.pop(0) is O(n), so using a reversed list is faster annotations = sorted(annotations, reverse=True) data_records = [] for i, start in enumerate(data_record_starts): end = start + data_record_duration tals: list[_EdfTAL] = [] if with_timestamps: tals.append(_EdfTAL(start + subsecond_offset, None, [""])) while annotations and ( annotations[-1].onset < end or i == num_data_records - 1 ): ann = annotations.pop() tals.append( _EdfTAL( ann.onset + subsecond_offset, ann.duration, [ann.text], ) ) data_records.append(_EdfAnnotationsDataRecord(tals).to_bytes()) maxlen = max(len(data_record) for data_record in data_records) maxlen = math.ceil(maxlen / bytes_per_sample) * bytes_per_sample raw = b"".join(dr.ljust(maxlen, b"\x00") for dr in data_records) divisor = data_record_duration if data_record_duration else 1 signal = signal_class( np.arange(1.0), # placeholder signal, as argument `data` is non-optional sampling_frequency=maxlen // bytes_per_sample / divisor, physical_range=signal_class._default_digital_range, ) signal._label = f"{signal_class._fmt} Annotations ".encode() signal._set_samples_per_data_record(maxlen // bytes_per_sample) signal._digital = np.frombuffer(raw, dtype=np.uint8).copy() # type: ignore[assignment] return signal # type: ignore[return-value] @dataclass class _EdfTAL: onset: float duration: float | None texts: list[str] def to_bytes(self) -> bytes: timing = _encode_annotation_onset(self.onset) if self.duration is not None: timing += f"\x15{_encode_annotation_duration(self.duration)}" texts_joined = "\x14".join(self.texts) return f"{timing}\x14{texts_joined}\x14".encode() @dataclass class _EdfAnnotationsDataRecord: tals: list[_EdfTAL] def to_bytes(self) -> bytes: return b"\x00".join(tal.to_bytes() for tal in self.tals) + b"\x00" @classmethod def from_bytes(cls, raw: bytes) -> _EdfAnnotationsDataRecord: tals: list[_EdfTAL] = [] matches: list[tuple[str, str, str]] = _ANNOTATIONS_PATTERN.findall(raw.decode()) if not matches and raw.replace(b"\x00", b""): raise ValueError(f"No valid annotations found in {raw!r}") for onset, duration, texts in matches: tals.append( _EdfTAL( float(onset), float(duration) if duration else None, list(texts.split("\x14")), ) ) return cls(tals) @property def annotations(self) -> list[EdfAnnotation]: return [ EdfAnnotation(tal.onset, tal.duration, text) for tal in self.tals for text in tal.texts ] def drop_annotations_with_text(self, text: str) -> None: for tal in self.tals: while text in tal.texts: tal.texts.remove(text) self.tals = [tal for tal in self.tals if tal.texts]