"""Data-access helpers for TUAR EEG recordings.

Covers three concerns:

* loading the recordings manifest (a Parquet file fetched over HTTPS),
* reading signal windows and channel metadata from Zarr stores on S3,
* loading artifact annotations from CSV companion files on S3.
"""

import io
import json
import os
import threading
import warnings
from functools import lru_cache
from typing import Optional

import boto3
import numpy as np
import pandas as pd
from dotenv import load_dotenv

# Load environment variables (e.g. AWS credentials) from a ".env" file that
# sits next to this module.
load_dotenv(os.path.join(os.path.dirname(__file__), ".env"))

# Suppress UserWarnings whose message mentions "asynchronous".
warnings.filterwarnings("ignore", category=UserWarning, message=".*asynchronous.*")

MANIFEST_DATASET = "tankalapavankalyan/eeg-corpus-manifest"
PARQUET_BASE = (
    "https://huggingface.co/datasets/tankalapavankalyan/eeg-corpus-manifest"
    "/resolve/refs%2Fconvert%2Fparquet"
)

# Manifest `dataset_id` values that are kept by get_tuar_recordings().
TUAR_DATASET_IDS = ["tuh_eeg_artifact"]

# Raw (lower-cased) annotation label -> human-readable display label.
ARTIFACT_LABEL_MAP = {
    "eyem": "Eye Movement",
    "musc": "Muscle",
    "elec": "Electrode Pop",
    "chew": "Chewing",
    "shiv": "Shiver",
    "null": "Clean",
    "elpp": "Electrode Pop",
    "artf": "Artifact (Generic)",
    "bckg": "Background",
    "eyem_musc": "Eye Movement + Muscle",
    "musc_elec": "Muscle + Electrode Pop",
    "eyem_elec": "Eye Movement + Electrode Pop",
    "eyem_chew": "Eye Movement + Chewing",
    "chew_elec": "Chewing + Electrode Pop",
    "chew_musc": "Chewing + Muscle",
    "eyem_shiv": "Eye Movement + Shiver",
    "shiv_elec": "Shiver + Electrode Pop",
}

# Display label -> CSS-style "rgba(...)" color string.
ARTIFACT_COLORS = {
    "Eye Movement": "rgba(30, 144, 255, 0.25)",
    "Muscle": "rgba(220, 20, 60, 0.25)",
    "Electrode Pop": "rgba(255, 165, 0, 0.25)",
    "Chewing": "rgba(50, 205, 50, 0.25)",
    "Shiver": "rgba(148, 103, 189, 0.25)",
    "Artifact (Generic)": "rgba(128, 128, 128, 0.25)",
    "Background": "rgba(200, 200, 200, 0.08)",
    "Clean": "rgba(200, 200, 200, 0.08)",
    "Eye Movement + Muscle": "rgba(125, 82, 158, 0.25)",
    "Muscle + Electrode Pop": "rgba(238, 93, 30, 0.25)",
    "Eye Movement + Electrode Pop": "rgba(143, 155, 128, 0.25)",
    "Eye Movement + Chewing": "rgba(40, 175, 153, 0.25)",
    "Chewing + Electrode Pop": "rgba(153, 185, 30, 0.25)",
    "Chewing + Muscle": "rgba(135, 113, 56, 0.25)",
    "Eye Movement + Shiver": "rgba(89, 124, 222, 0.25)",
    "Shiver + Electrode Pop": "rgba(202, 134, 95, 0.25)",
}

# Lazily created, process-wide boto3 S3 client (see _get_boto_client).
_boto_client = None
# Per-thread storage for the s3fs filesystem (see _get_thread_s3fs).
_thread_local = threading.local()


def _get_boto_client():
    """Return the shared boto3 S3 client, creating it on first use.

    Explicit credentials from ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY``
    are passed when both are set; otherwise the client is created with only a
    region. The region comes from ``AWS_DEFAULT_REGION`` (default
    ``"us-east-1"``).
    """
    global _boto_client
    if _boto_client is None:
        key = os.environ.get("AWS_ACCESS_KEY_ID")
        secret = os.environ.get("AWS_SECRET_ACCESS_KEY")
        region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
        if key and secret:
            _boto_client = boto3.client(
                "s3",
                aws_access_key_id=key,
                aws_secret_access_key=secret,
                region_name=region,
            )
        else:
            _boto_client = boto3.client("s3", region_name=region)
    return _boto_client


def _get_thread_s3fs():
    """Create a per-thread s3fs instance to avoid async event loop conflicts.

    The instance is stored on ``_thread_local`` and reused by later calls from
    the same thread. Without both credential variables set, an anonymous
    (``anon=True``) filesystem is created.
    """
    import s3fs

    fs = getattr(_thread_local, "fs", None)
    if fs is None:
        key = os.environ.get("AWS_ACCESS_KEY_ID")
        secret = os.environ.get("AWS_SECRET_ACCESS_KEY")
        region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
        if key and secret:
            fs = s3fs.S3FileSystem(
                key=key,
                secret=secret,
                client_kwargs={"region_name": region},
                asynchronous=False,
            )
        else:
            fs = s3fs.S3FileSystem(
                anon=True,
                client_kwargs={"region_name": region},
                asynchronous=False,
            )
        _thread_local.fs = fs
    return fs


def reset_s3fs():
    """Drop the cached S3 client, thread-local state and all module caches.

    ``threading.local`` attributes are per-thread, so only the calling
    thread's s3fs instance is discarded here.
    """
    global _boto_client
    _boto_client = None
    _thread_local.__dict__.clear()
    _zarr_cache.clear()
    _scale_cache.clear()
    _annotation_cache.clear()


# Columns read from the manifest Parquet file.
MANIFEST_COLUMNS = [
    "recording_id",
    "dataset_id",
    "subject_id_in_dataset",
    "session_id",
    "run_id",
    "task",
    "archival_uri",
    "archival_format",
    "duration_s",
    "n_channels",
    "n_eeg_channels",
    "sampling_rate_hz",
    "reference",
    "montage_name",
    "recording_type",
    "channel_names",
    "canonical_uri",
    "conversion_status",
    "roundtrip_class",
]


def get_tuar_recordings() -> pd.DataFrame:
    """Load the manifest and return the usable TUAR recordings.

    Reads ``MANIFEST_COLUMNS`` from the manifest Parquet file, keeps rows whose
    ``dataset_id`` is in ``TUAR_DATASET_IDS`` and whose ``conversion_status``
    is ``"ok"``, and returns them sorted by ``subject_id_in_dataset`` with a
    fresh 0..n-1 index.
    """
    url = f"{PARQUET_BASE}/recordings/train/0000.parquet"
    df = pd.read_parquet(url, columns=MANIFEST_COLUMNS)
    mask = df["dataset_id"].isin(TUAR_DATASET_IDS) & (df["conversion_status"] == "ok")
    result = df[mask].copy()
    # Release the full manifest before sorting the (smaller) filtered copy.
    del df
    result = result.sort_values("subject_id_in_dataset").reset_index(drop=True)
    return result


def get_recording_display_list(df: pd.DataFrame) -> list[str]:
    """Build one human-readable label per row of *df*, in row order.

    Each label contains the first 8 characters of the recording id, the
    dataset id, subject, session, duration, channel count and sampling rate.
    """
    entries = []
    for _, row in df.iterrows():
        label = (
            f"{row['recording_id'][:8]}… | "
            f"{row['dataset_id']} | "
            f"subj={row['subject_id_in_dataset']} | "
            f"ses={row.get('session_id', 'N/A')} | "
            f"dur={row['duration_s']:.0f}s | "
            f"{row['n_channels']:.0f}ch @ {row['sampling_rate_hz']:.0f}Hz"
        )
        entries.append(label)
    return entries


# --- Signal access via direct Zarr + S3 ---

# canonical_uri -> opened Zarr root group (bounded FIFO, see _open_zarr).
_zarr_cache: dict[str, object] = {}
# canonical_uri -> (scale, offset) float32 arrays (see _get_scale_offset).
_scale_cache: dict[str, tuple[np.ndarray, np.ndarray]] = {}


def _register_flac_codec():
    """Register the FLAC codec with zarr v3 if not already done."""
    try:
        from zarr.registry import get_codec_class

        # A KeyError here is taken to mean the codec is not registered yet.
        get_codec_class("numcodecs.flac")
    except KeyError:
        import numcodecs
        from flac_numcodecs import Flac

        numcodecs.register_codec(Flac)

        from zarr.codecs.numcodecs._codecs import _NumcodecsArrayBytesCodec
        from zarr.codecs import register_codec

        class FlacCodec(_NumcodecsArrayBytesCodec):
            codec_name = "numcodecs.flac"

            def __init__(self, **kwargs):
                super().__init__(**kwargs)

        register_codec("numcodecs.flac", FlacCodec)


def _open_zarr(canonical_uri: str):
    """Open (or fetch from cache) the read-only Zarr root group for a store.

    The cache holds at most 50 groups; when it grows past that, the entry that
    was inserted first is evicted together with its scale/offset entry.

    NOTE(review): the cache is shared by all threads while the underlying
    s3fs instance is per-thread, so a group opened on one thread can be handed
    to another — confirm this is intended.
    """
    if canonical_uri in _zarr_cache:
        return _zarr_cache[canonical_uri]

    _register_flac_codec()
    fs = _get_thread_s3fs()

    # The store path is "bucket/key/" without the URI scheme.
    s3_path = canonical_uri.replace("s3://", "")
    if not s3_path.endswith("/"):
        s3_path += "/"

    import zarr

    fsspec_store = zarr.storage.FsspecStore(fs=fs, path=s3_path, read_only=True)
    root = zarr.open_group(fsspec_store, mode="r")

    _zarr_cache[canonical_uri] = root
    if len(_zarr_cache) > 50:
        # dicts preserve insertion order, so the first key is the oldest entry.
        oldest = next(iter(_zarr_cache))
        del _zarr_cache[oldest]
        _scale_cache.pop(oldest, None)
    return root


def _get_scale_offset(canonical_uri: str):
    """Return cached per-channel ``(scale, offset)`` float32 arrays.

    They are derived from the ``physical_min/max`` and ``digital_min/max``
    arrays in the store's ``channels`` group so that
    ``physical = digital * scale + offset``.
    """
    if canonical_uri in _scale_cache:
        return _scale_cache[canonical_uri]

    root = _open_zarr(canonical_uri)
    ch_grp = root["channels"]
    phys_min = np.array(ch_grp["physical_min"][:], dtype=np.float64)
    phys_max = np.array(ch_grp["physical_max"][:], dtype=np.float64)
    dig_min = np.array(ch_grp["digital_min"][:], dtype=np.float64)
    dig_max = np.array(ch_grp["digital_max"][:], dtype=np.float64)

    # The small epsilon avoids division by zero when dig_max == dig_min.
    scale = (phys_max - phys_min) / (dig_max - dig_min + 1e-12)
    offset = phys_min - dig_min * scale

    _scale_cache[canonical_uri] = (scale.astype(np.float32), offset.astype(np.float32))
    return _scale_cache[canonical_uri]


def read_signal_window(
    canonical_uri: str,
    start_sample: int,
    end_sample: int,
    channel_indices: Optional[list[int]] = None,
    *,
    smoothing_kernel_size: int = 5,
) -> np.ndarray:
    """Read a window of the signal, scaled to physical units and smoothed.

    Args:
        canonical_uri: URI of the Zarr store.
        start_sample: First sample index (inclusive).
        end_sample: Last sample index (exclusive).
        channel_indices: Rows of the ``signal`` array to read; ``None`` reads
            every channel.
        smoothing_kernel_size: Width of the moving-average filter applied to
            each channel. Smoothing is skipped when this is 1 or less, or when
            the window is not longer than the kernel.

    Returns:
        A ``(channels, samples)`` float32 array.
    """
    root = _open_zarr(canonical_uri)
    sig_arr = root["signal"]
    if channel_indices is not None:
        raw = np.array(sig_arr[channel_indices, start_sample:end_sample], dtype=np.float32)
    else:
        raw = np.array(sig_arr[:, start_sample:end_sample], dtype=np.float32)

    scale, offset = _get_scale_offset(canonical_uri)
    if channel_indices is not None:
        scale = scale[channel_indices]
        offset = offset[channel_indices]
    data = raw * scale[:, None] + offset[:, None]

    kernel_size = smoothing_kernel_size
    if kernel_size > 1 and data.shape[1] > kernel_size:
        kernel = np.ones(kernel_size) / kernel_size
        for i in range(data.shape[0]):
            # mode="same" keeps each row at its original length.
            data[i] = np.convolve(data[i], kernel, mode="same")
    return data


def get_channel_names(canonical_uri: str) -> list[str]:
    """Return the entries of the store's ``channels/name`` array as a list."""
    root = _open_zarr(canonical_uri)
    return list(root["channels"]["name"][:])


def get_store_metadata(canonical_uri: str) -> dict:
    """Collect signal shape, channel names and selected root attributes.

    Attributes missing from the store are reported as ``None``.
    """
    root = _open_zarr(canonical_uri)
    attrs = dict(root.attrs)
    return {
        "n_channels": root["signal"].shape[0],
        "n_samples": root["signal"].shape[1],
        "sampling_rate_hz": attrs.get("sampling_rate_hz"),
        "duration_s": attrs.get("duration_s"),
        "channel_names": list(root["channels"]["name"][:]),
        "reference": attrs.get("reference"),
        "montage_name": attrs.get("montage_name"),
        "recording_type": attrs.get("recording_type"),
        "manufacturer": attrs.get("manufacturer"),
        "source_uri": attrs.get("source_uri"),
    }


# --- TUAR artifact annotations from CSV companion files ---

# canonical_uri -> parsed annotations (an empty list also records a failure).
_annotation_cache: dict[str, list[dict]] = {}


def _get_csv_path_from_source_uri(source_uri: str) -> Optional[str]:
    """Map an ``.edf`` source URI to the ``bucket/key`` of its ``.csv`` sibling.

    Returns ``None`` when *source_uri* is empty, is not a string (e.g. a NaN
    coming out of a DataFrame), or does not end in ``.edf``.
    """
    if not isinstance(source_uri, str) or not source_uri.endswith(".edf"):
        return None
    return source_uri.replace("s3://", "").rsplit(".", 1)[0] + ".csv"


def get_annotations(canonical_uri: str, source_uri: Optional[str] = None) -> list[dict]:
    """Return the parsed artifact annotations for a recording (cached).

    Args:
        canonical_uri: URI of the Zarr store; also the cache key.
        source_uri: URI of the original ``.edf`` file. When ``None`` it is
            read from the ``source_uri`` attribute of the Zarr root group.

    Returns:
        Annotation dicts as produced by ``_parse_tuar_csv``. An empty list is
        returned (and cached) when the store cannot be opened, no CSV path can
        be derived, or the CSV cannot be fetched.
    """
    cache_key = canonical_uri
    if cache_key in _annotation_cache:
        return _annotation_cache[cache_key]

    if source_uri is None:
        try:
            # Same attribute that get_store_metadata() reports as "source_uri".
            source_uri = _open_zarr(canonical_uri).attrs.get("source_uri")
        except Exception:
            # Best effort: an unreadable store simply has no annotations.
            _annotation_cache[cache_key] = []
            return []

    csv_path = _get_csv_path_from_source_uri(source_uri)
    if csv_path is None:
        _annotation_cache[cache_key] = []
        return []

    try:
        bucket, key = csv_path.split("/", 1)
        resp = _get_boto_client().get_object(Bucket=bucket, Key=key)
        raw = resp["Body"].read().decode("utf-8")
    except Exception:
        # Best effort: a missing or unreadable CSV means "no annotations".
        _annotation_cache[cache_key] = []
        return []

    annotations = _parse_tuar_csv(raw)
    _annotation_cache[cache_key] = annotations
    return annotations


def preload_all_annotations(df: pd.DataFrame) -> None:
    """Fetch and cache the annotation CSV of every uncached recording in *df*.

    Issues one S3 ``get_object`` request per CSV file, sequentially. Rows
    without a ``canonical_uri``, rows already cached, and rows whose
    ``archival_uri`` does not map to a CSV path are skipped. A failed fetch or
    parse caches an empty list for that recording.
    """
    # Keyed by canonical_uri so recordings that share a CSV are all cached.
    pending: dict[str, str] = {}
    for _, row in df.iterrows():
        canonical_uri = row.get("canonical_uri", "")
        archival_uri = row.get("archival_uri", "")
        if not canonical_uri or canonical_uri in _annotation_cache:
            continue
        csv_path = _get_csv_path_from_source_uri(archival_uri)
        if csv_path:
            pending[canonical_uri] = csv_path

    if not pending:
        return

    client = _get_boto_client()
    for canonical_uri, csv_path in pending.items():
        try:
            bucket, key = csv_path.split("/", 1)
            resp = client.get_object(Bucket=bucket, Key=key)
            raw = resp["Body"].read().decode("utf-8")
            _annotation_cache[canonical_uri] = _parse_tuar_csv(raw)
        except Exception:
            # Best effort: one bad file must not abort the whole preload.
            _annotation_cache[canonical_uri] = []


def _parse_tuar_csv(raw: str) -> list[dict]:
    """Parse annotation CSV text into a list of dicts sorted by onset.

    Lines starting with ``#`` and blank lines are ignored, and a first line
    containing both ``channel`` and ``start_time`` is treated as a header.
    Rows are ``channel,start,stop,label[,confidence]``; rows with fewer than
    four fields or non-numeric start/stop are skipped. A missing or malformed
    confidence defaults to 1.0.

    Rows are de-duplicated on (start, stop, label) rounded to milliseconds.
    The channel is not part of that key, so only the first channel seen for a
    given interval and label is kept.
    """
    lines = [ln for ln in raw.strip().split("\n") if not ln.startswith("#") and ln.strip()]
    if not lines:
        return []

    header_line = lines[0]
    if "channel" in header_line and "start_time" in header_line:
        lines = lines[1:]

    seen = set()
    annotations = []
    for line in lines:
        parts = line.strip().split(",")
        if len(parts) < 4:
            continue
        channel = parts[0].strip()
        try:
            start = float(parts[1].strip())
            stop = float(parts[2].strip())
        except ValueError:
            continue
        raw_label = parts[3].strip().lower()

        confidence = 1.0
        if len(parts) > 4:
            try:
                confidence = float(parts[4].strip())
            except ValueError:
                # Keep the annotation; only the confidence was unreadable.
                confidence = 1.0

        dedup_key = (round(start, 3), round(stop, 3), raw_label)
        if dedup_key in seen:
            continue
        seen.add(dedup_key)

        # Unknown labels fall back to a title-cased name and a neutral grey.
        label = ARTIFACT_LABEL_MAP.get(raw_label, raw_label.title())
        color = ARTIFACT_COLORS.get(label, "rgba(128, 128, 128, 0.2)")
        annotations.append({
            "onset_s": start,
            "duration_s": stop - start,
            "end_s": stop,
            "raw_label": raw_label,
            "label": label,
            "color": color,
            "channel": channel,
            "confidence": confidence,
        })

    annotations.sort(key=lambda a: a["onset_s"])
    return annotations


def get_annotations_in_window(
    canonical_uri: str,
    start_s: float,
    end_s: float,
    source_uri: Optional[str] = None,
) -> list[dict]:
    """Return the annotations that overlap the open interval (start_s, end_s).

    An annotation is included when it ends after *start_s* and starts before
    *end_s*.
    """
    all_ann = get_annotations(canonical_uri, source_uri)
    return [a for a in all_ann if a["end_s"] > start_s and a["onset_s"] < end_s]


def get_recording_info(row: pd.Series) -> dict:
    """Flatten a manifest row into a plain dict with defaults for missing keys.

    ``channel_names`` may be a JSON string or any iterable; it is always
    returned as a list. A value that cannot be decoded or iterated (invalid
    JSON, ``None``, NaN, a scalar) yields an empty list.
    """
    channel_names = row.get("channel_names", [])
    if isinstance(channel_names, str):
        try:
            channel_names = json.loads(channel_names)
        except (json.JSONDecodeError, TypeError):
            channel_names = []
    if not isinstance(channel_names, list):
        try:
            channel_names = list(channel_names)
        except TypeError:
            # None, NaN or any other non-iterable value.
            channel_names = []

    return {
        "recording_id": row["recording_id"],
        "dataset_id": row["dataset_id"],
        "subject": row.get("subject_id_in_dataset", "N/A"),
        "session": row.get("session_id", "N/A"),
        "task": row.get("task", "N/A"),
        "duration_s": row.get("duration_s", 0),
        "n_channels": row.get("n_channels", 0),
        "n_eeg_channels": row.get("n_eeg_channels", 0),
        "sampling_rate_hz": row.get("sampling_rate_hz", 0),
        "reference": row.get("reference", "N/A"),
        "montage_name": row.get("montage_name", "N/A"),
        "recording_type": row.get("recording_type", "N/A"),
        "archival_format": row.get("archival_format", "N/A"),
        "canonical_uri": row.get("canonical_uri", ""),
        "archival_uri": row.get("archival_uri", ""),
        "channel_names": channel_names,
        "roundtrip_class": row.get("roundtrip_class", "N/A"),
    }