Spaces:
Running
Running
| from __future__ import annotations | |
| import math | |
| import re | |
| from collections.abc import Iterable | |
| from dataclasses import dataclass | |
| from typing import Any, NamedTuple, TypeVar, Union | |
| import numpy as np | |
| from edfio.edf_signal import BdfSignal, EdfSignal | |
# Matches one TAL (time-stamped annotation list) in a decoded annotations data
# record. The three groups are: the signed onset, the unsigned duration (empty
# string when absent), and the annotation texts still joined by "\x14".
# re.DOTALL lets annotation texts contain line breaks. Without it, `.` stops at
# "\n", so such a TAL would fail to match and the scan could resume inside its
# text.
_ANNOTATIONS_PATTERN = re.compile(
    """
    ([+-]\\d+(?:\\.?\\d+)?) # onset
    (?:\x15(\\d+(?:\\.?\\d+)?))? # duration, optional
    (?:\x14(.*?)) # annotation texts
    \x14\x00 # terminator
    """,
    re.VERBOSE | re.DOTALL,
)
def _encode_annotation_onset(onset: float) -> str:
    """Encode an annotation onset for use in a TAL.

    The result always has an explicit sign and at most 12 decimal places.
    Trailing zeros and a dangling decimal point are removed, e.g.
    ``1.5 -> "+1.5"``, ``3.0 -> "+3"``, ``-2.25 -> "-2.25"``.

    Raises
    ------
    ValueError
        If `onset` is NaN or infinite. Those would be written as ``nan``/``inf``,
        which is not a valid TAL onset.
    """
    if not math.isfinite(onset):
        raise ValueError(f"Annotation onset must be finite, is {onset}")
    string = f"{onset:+.12f}".rstrip("0")
    # e.g. "+3." after stripping zeros -> drop the bare decimal point
    if string[-1] == ".":
        return string[:-1]
    return string
def _encode_annotation_duration(duration: float) -> str:
    """Encode an annotation duration for use in a TAL.

    The result is unsigned with at most 12 decimal places. Trailing zeros and
    a dangling decimal point are removed, e.g. ``2.5 -> "2.5"``, ``10.0 -> "10"``.
    A duration of zero is accepted and encoded as ``"0"``.

    Raises
    ------
    ValueError
        If `duration` is negative, NaN, or infinite.
    """
    if duration < 0:
        raise ValueError(f"Annotation duration must be positive, is {duration}")
    # NaN slips past the `< 0` check above. NaN/inf would otherwise be written as
    # "nan"/"inf", which the TAL duration syntax cannot represent.
    if not math.isfinite(duration):
        raise ValueError(f"Annotation duration must be finite, is {duration}")
    string = f"{duration:.12f}".rstrip("0")
    if string[-1] == ".":
        return string[:-1]
    return string
| class EdfAnnotation(NamedTuple): | |
| """A single EDF+ annotation. | |
| Parameters | |
| ---------- | |
| onset : float | |
| The annotation onset in seconds from recording start. | |
| duration : float | None | |
| The annotation duration in seconds (`None` if annotation has no duration). | |
| text : str | |
| The annotation text, can be empty. | |
| """ | |
| onset: float | |
| duration: float | None | |
| text: str | |
| def __lt__(self, other: Any) -> bool: | |
| if not isinstance(other, EdfAnnotation): | |
| return NotImplemented # pragma: no cover | |
| return ( | |
| self.onset, | |
| -1 if self.duration is None else self.duration, | |
| self.text, | |
| ) < ( | |
| other.onset, | |
| -1 if other.duration is None else other.duration, | |
| other.text, | |
| ) | |
# Type variable for the concrete signal class passed to `_create_annotations_signal`
# as `signal_class`. The function returns an instance of that same class, so
# callers get back the signal type (EDF or BDF) they asked for.
_Signal = TypeVar("_Signal", bound=Union[EdfSignal, BdfSignal])
def _create_annotations_signal(
    annotations: Iterable[EdfAnnotation],
    *,
    num_data_records: int,
    data_record_duration: float,
    signal_class: type[_Signal],
    with_timestamps: bool = True,
    subsecond_offset: float = 0,
) -> _Signal:
    """Build an annotations signal of type `signal_class` holding `annotations`.

    Data record ``i`` spans ``[i * data_record_duration, (i + 1) *
    data_record_duration)``. Annotations are taken in sorted order. Each one is
    placed in the first data record whose end lies after its onset. Annotations
    that start at or after the end of the last record are all placed in the
    last record.

    Parameters
    ----------
    annotations : Iterable[EdfAnnotation]
        The annotations to store. They are sorted internally.
    num_data_records : int
        Number of data records to produce. With 0 records, computing the record
        length fails with a `ValueError` from `max()`.
    data_record_duration : float
        Duration of one data record in seconds. If it is 0, the sampling
        frequency is computed as if the duration were 1.
    signal_class : type[_Signal]
        The signal class (EDF or BDF flavor) to instantiate.
    with_timestamps : bool
        If True, every data record begins with a TAL that holds the record's
        start time, no duration, and an empty text (the time-keeping annotation).
    subsecond_offset : float
        Added to every onset that is written, including the time-keeping ones.

    Returns
    -------
    _Signal
        A signal whose raw bytes are the serialized TALs of each data record.
        Each record is NUL-padded to a common length that is a multiple of the
        signal class's bytes per sample.
    """
    sample_size = signal_class._bytes_per_sample
    last_index = num_data_records - 1
    # Descending order leaves the earliest annotation at the end of the list,
    # so it can be taken with O(1) list.pop() instead of O(n) list.pop(0).
    pending = sorted(annotations, reverse=True)

    encoded_records: list[bytes] = []
    record_starts = np.arange(num_data_records) * data_record_duration
    for index, record_start in enumerate(record_starts):
        record_end = record_start + data_record_duration
        tals: list[_EdfTAL] = []
        if with_timestamps:
            tals.append(_EdfTAL(record_start + subsecond_offset, None, [""]))
        # The final record absorbs every annotation still pending.
        while pending and (pending[-1].onset < record_end or index == last_index):
            annotation = pending.pop()
            tals.append(
                _EdfTAL(
                    annotation.onset + subsecond_offset,
                    annotation.duration,
                    [annotation.text],
                )
            )
        encoded_records.append(_EdfAnnotationsDataRecord(tals).to_bytes())

    # Every record must have the same whole number of samples.
    longest = max(len(record) for record in encoded_records)
    samples_per_record = math.ceil(longest / sample_size)
    record_size = samples_per_record * sample_size
    raw = b"".join(record.ljust(record_size, b"\x00") for record in encoded_records)

    signal = signal_class(
        np.arange(1.0),  # placeholder signal, as argument `data` is non-optional
        sampling_frequency=samples_per_record / (data_record_duration or 1),
        physical_range=signal_class._default_digital_range,
    )
    signal._label = f"{signal_class._fmt} Annotations ".encode()
    signal._set_samples_per_data_record(samples_per_record)
    signal._digital = np.frombuffer(raw, dtype=np.uint8).copy()  # type: ignore[assignment]
    return signal  # type: ignore[return-value]
| class _EdfTAL: | |
| onset: float | |
| duration: float | None | |
| texts: list[str] | |
| def to_bytes(self) -> bytes: | |
| timing = _encode_annotation_onset(self.onset) | |
| if self.duration is not None: | |
| timing += f"\x15{_encode_annotation_duration(self.duration)}" | |
| texts_joined = "\x14".join(self.texts) | |
| return f"{timing}\x14{texts_joined}\x14".encode() | |
@dataclass
class _EdfAnnotationsDataRecord:
    """The content of one data record of an annotations signal.

    The record is a sequence of TALs, each followed by a NUL byte. Any bytes
    after the last TAL are NUL padding.
    """

    tals: list[_EdfTAL]

    def to_bytes(self) -> bytes:
        """Serialize all TALs, each followed by a NUL terminator.

        An empty record serializes to a single NUL byte.
        """
        return b"\x00".join(tal.to_bytes() for tal in self.tals) + b"\x00"

    @classmethod
    def from_bytes(cls, raw: bytes) -> _EdfAnnotationsDataRecord:
        """Parse the raw bytes of one annotations data record.

        Raises
        ------
        ValueError
            If `raw` holds non-NUL bytes but no TAL can be parsed from it. A
            `UnicodeDecodeError` (a `ValueError` subclass) is raised if `raw` is
            not valid UTF-8.
        """
        matches: list[tuple[str, str, str]] = _ANNOTATIONS_PATTERN.findall(raw.decode())
        # An all-NUL (pure padding) record legitimately contains no TALs.
        if not matches and raw.replace(b"\x00", b""):
            raise ValueError(f"No valid annotations found in {raw!r}")
        return cls(
            [
                _EdfTAL(
                    float(onset),
                    # The duration group is "" when the TAL has no duration.
                    float(duration) if duration else None,
                    texts.split("\x14"),
                )
                for onset, duration, texts in matches
            ]
        )

    # NOTE(review): the noun-style name suggests this may have been meant as a
    # @property (other decorators were missing from this class). It is kept as
    # a plain method to preserve the visible interface; confirm against callers.
    def annotations(self) -> list[EdfAnnotation]:
        """Flatten the TALs into one `EdfAnnotation` per text, in TAL order."""
        return [
            EdfAnnotation(tal.onset, tal.duration, text)
            for tal in self.tals
            for text in tal.texts
        ]

    def drop_annotations_with_text(self, text: str) -> None:
        """Remove every annotation text equal to `text`, then drop empty TALs."""
        for tal in self.tals:
            # Slice assignment filters in a single pass instead of repeated
            # `remove` calls. It also mutates the existing list, so other
            # references to `tal.texts` see the change.
            tal.texts[:] = [t for t in tal.texts if t != text]
        self.tals = [tal for tal in self.tals if tal.texts]