# tuar-eeg-dashboard / data_loader.py
# Commit 58c475e (verified) by Gunin09:
#   Remove Eye Blink from label map and legend (not in TUAR)
import io
import json
import os
import threading
import warnings
from functools import lru_cache
from typing import Optional
import boto3
import numpy as np
import pandas as pd
from dotenv import load_dotenv
# Load environment variables from a ".env" file located next to this module,
# so the AWS_* settings read further down can be supplied that way.
load_dotenv(os.path.join(os.path.dirname(__file__), ".env"))
# Silence UserWarnings whose message mentions "asynchronous".
warnings.filterwarnings("ignore", category=UserWarning, message=".*asynchronous.*")
# Hugging Face dataset id of the recording manifest.
MANIFEST_DATASET = "tankalapavankalyan/eeg-corpus-manifest"
# Base URL under which the manifest's parquet files are resolved.
# "refs%2Fconvert%2Fparquet" is the URL-encoded ref "refs/convert/parquet".
PARQUET_BASE = (
    "https://huggingface.co/datasets/tankalapavankalyan/eeg-corpus-manifest"
    "/resolve/refs%2Fconvert%2Fparquet"
)
# Manifest dataset_id values that get_tuar_recordings keeps.
TUAR_DATASET_IDS = ["tuh_eeg_artifact"]
# Maps the lower-cased raw label found in an annotation CSV to the display
# label shown to the user. Labels missing from this table fall back to
# `raw_label.title()` in _parse_tuar_csv.
# Note that both "elec" and "elpp" map to the same display label.
ARTIFACT_LABEL_MAP = {
    # Single-artifact labels.
    "eyem": "Eye Movement",
    "musc": "Muscle",
    "elec": "Electrode Pop",
    "chew": "Chewing",
    "shiv": "Shiver",
    "null": "Clean",
    "elpp": "Electrode Pop",
    "artf": "Artifact (Generic)",
    "bckg": "Background",
    # Combined labels (two artifact codes joined by an underscore).
    "eyem_musc": "Eye Movement + Muscle",
    "musc_elec": "Muscle + Electrode Pop",
    "eyem_elec": "Eye Movement + Electrode Pop",
    "eyem_chew": "Eye Movement + Chewing",
    "chew_elec": "Chewing + Electrode Pop",
    "chew_musc": "Chewing + Muscle",
    "eyem_shiv": "Eye Movement + Shiver",
    "shiv_elec": "Shiver + Electrode Pop",
}
# Semi-transparent CSS rgba() colour per display label (the values of
# ARTIFACT_LABEL_MAP). Labels missing here get "rgba(128, 128, 128, 0.2)"
# in _parse_tuar_csv.
ARTIFACT_COLORS = {
    "Eye Movement": "rgba(30, 144, 255, 0.25)",
    "Muscle": "rgba(220, 20, 60, 0.25)",
    "Electrode Pop": "rgba(255, 165, 0, 0.25)",
    "Chewing": "rgba(50, 205, 50, 0.25)",
    "Shiver": "rgba(148, 103, 189, 0.25)",
    "Artifact (Generic)": "rgba(128, 128, 128, 0.25)",
    # Non-artifact labels share one colour with a much lower alpha.
    "Background": "rgba(200, 200, 200, 0.08)",
    "Clean": "rgba(200, 200, 200, 0.08)",
    # Combination colours are approximately the per-channel average of the
    # two component colours above.
    "Eye Movement + Muscle": "rgba(125, 82, 158, 0.25)",
    "Muscle + Electrode Pop": "rgba(238, 93, 30, 0.25)",
    "Eye Movement + Electrode Pop": "rgba(143, 155, 128, 0.25)",
    "Eye Movement + Chewing": "rgba(40, 175, 153, 0.25)",
    "Chewing + Electrode Pop": "rgba(153, 185, 30, 0.25)",
    "Chewing + Muscle": "rgba(135, 113, 56, 0.25)",
    "Eye Movement + Shiver": "rgba(89, 124, 222, 0.25)",
    "Shiver + Electrode Pop": "rgba(202, 134, 95, 0.25)",
}
# Module-wide boto3 S3 client, created on first use by _get_boto_client.
_boto_client = None
# Thread-local storage; _get_thread_s3fs keeps its filesystem in the "fs"
# attribute of this object.
_thread_local = threading.local()
def _get_boto_client():
    """Return the module-wide boto3 S3 client, building it on first call.

    Explicit credentials are passed only when both AWS_ACCESS_KEY_ID and
    AWS_SECRET_ACCESS_KEY are set; otherwise the client is created with the
    region alone. The region comes from AWS_DEFAULT_REGION and defaults to
    "us-east-1".
    """
    global _boto_client
    if _boto_client is not None:
        return _boto_client

    env = os.environ
    access_key = env.get("AWS_ACCESS_KEY_ID")
    secret_key = env.get("AWS_SECRET_ACCESS_KEY")
    client_options = {"region_name": env.get("AWS_DEFAULT_REGION", "us-east-1")}
    if access_key and secret_key:
        client_options["aws_access_key_id"] = access_key
        client_options["aws_secret_access_key"] = secret_key

    _boto_client = boto3.client("s3", **client_options)
    return _boto_client
def _get_thread_s3fs():
    """Return the s3fs filesystem stored for the current thread, creating it if absent.

    One instance is kept per thread (on ``_thread_local.fs``) to avoid async
    event loop conflicts. With both AWS key variables set the filesystem is
    authenticated; otherwise it is opened anonymously.
    """
    import s3fs

    existing = getattr(_thread_local, "fs", None)
    if existing is not None:
        return existing

    access_key = os.environ.get("AWS_ACCESS_KEY_ID")
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    aws_region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")

    # Keyword order is kept the same as in the direct calls it replaces.
    if access_key and secret_key:
        auth = {"key": access_key, "secret": secret_key}
    else:
        auth = {"anon": True}
    created = s3fs.S3FileSystem(
        **auth,
        client_kwargs={"region_name": aws_region},
        asynchronous=False,
    )
    _thread_local.fs = created
    return created
def reset_s3fs():
    """Drop the cached S3 clients and every module-level cache.

    Empties the zarr, scale and annotation caches, clears the attributes
    stored on ``_thread_local`` as seen from the calling thread, and forgets
    the boto3 client so the next use rebuilds it.
    """
    global _boto_client
    for cache in (_zarr_cache, _scale_cache, _annotation_cache):
        cache.clear()
    vars(_thread_local).clear()
    _boto_client = None
# Columns requested from the manifest parquet file by get_tuar_recordings;
# restricting the read to these avoids loading any other manifest columns.
MANIFEST_COLUMNS = [
    "recording_id", "dataset_id", "subject_id_in_dataset", "session_id",
    "run_id", "task", "archival_uri", "archival_format", "duration_s",
    "n_channels", "n_eeg_channels", "sampling_rate_hz", "reference",
    "montage_name", "recording_type", "channel_names", "canonical_uri",
    "conversion_status", "roundtrip_class",
]
def get_tuar_recordings() -> pd.DataFrame:
    """Load the manifest and return the successfully converted TUAR recordings.

    Returns:
        A copy of the manifest rows whose ``dataset_id`` is listed in
        TUAR_DATASET_IDS and whose ``conversion_status`` is "ok", sorted by
        ``subject_id_in_dataset`` with a fresh 0..n-1 index.
    """
    manifest = pd.read_parquet(
        f"{PARQUET_BASE}/recordings/train/0000.parquet",
        columns=MANIFEST_COLUMNS,
    )
    is_tuar = manifest["dataset_id"].isin(TUAR_DATASET_IDS)
    is_converted = manifest["conversion_status"] == "ok"
    selected = manifest[is_tuar & is_converted].copy()
    # Release the full manifest before sorting; only the filtered copy is needed.
    del manifest
    return selected.sort_values("subject_id_in_dataset").reset_index(drop=True)
def get_recording_display_list(df: pd.DataFrame) -> list[str]:
    """Build one human-readable summary line per manifest row, in row order.

    Each line shows the first eight characters of the recording id, the
    dataset, subject, session, duration in whole seconds, channel count and
    sampling rate.
    """

    def _describe(rec) -> str:
        short_id = rec['recording_id'][:8]
        session = rec.get('session_id', 'N/A')
        return (
            f"{short_id}… | "
            f"{rec['dataset_id']} | "
            f"subj={rec['subject_id_in_dataset']} | "
            f"ses={session} | "
            f"dur={rec['duration_s']:.0f}s | "
            f"{rec['n_channels']:.0f}ch @ {rec['sampling_rate_hz']:.0f}Hz"
        )

    return [_describe(rec) for _, rec in df.iterrows()]
# --- Signal access via direct Zarr + S3 ---
# Opened Zarr root groups keyed by canonical_uri; filled and size-bounded by
# _open_zarr.
_zarr_cache: dict[str, object] = {}
# Per-recording (scale, offset) float32 arrays keyed by canonical_uri; filled
# by _get_scale_offset and evicted together with the matching zarr entry.
_scale_cache: dict[str, tuple[np.ndarray, np.ndarray]] = {}
def _register_flac_codec():
    """Register the FLAC codec with zarr v3 if not already done."""
    try:
        from zarr.registry import get_codec_class
        # Used purely as an "is it registered?" probe; the result is discarded.
        # NOTE(review): relies on get_codec_class raising KeyError for an
        # unknown codec name — confirm against the installed zarr version.
        get_codec_class("numcodecs.flac")
    except KeyError:
        # First make the FLAC implementation known to numcodecs itself.
        import numcodecs
        from flac_numcodecs import Flac
        numcodecs.register_codec(Flac)
        # NOTE(review): _NumcodecsArrayBytesCodec comes from a private zarr
        # module and may move between zarr releases — confirm when upgrading.
        from zarr.codecs.numcodecs._codecs import _NumcodecsArrayBytesCodec
        from zarr.codecs import register_codec

        # Thin zarr-side wrapper exposing the numcodecs codec under the
        # "numcodecs.flac" name.
        class FlacCodec(_NumcodecsArrayBytesCodec):
            codec_name = "numcodecs.flac"

            def __init__(self, **kwargs):
                super().__init__(**kwargs)

        register_codec("numcodecs.flac", FlacCodec)
def _open_zarr(canonical_uri: str):
    """Open (or fetch from cache) the read-only Zarr root group for a recording.

    The group is cached under ``canonical_uri``. Once the cache holds more
    than 50 entries the oldest-inserted one is dropped, together with its
    entry in the scale cache.

    NOTE(review): the cache is module-global while the filesystem backing
    each group is per-thread — confirm groups are never shared across threads.
    """
    try:
        return _zarr_cache[canonical_uri]
    except KeyError:
        pass

    _register_flac_codec()
    filesystem = _get_thread_s3fs()

    # The store wants "bucket/key/" without the scheme and with a trailing slash.
    store_path = canonical_uri.replace("s3://", "")
    if not store_path.endswith("/"):
        store_path = store_path + "/"

    import zarr

    store = zarr.storage.FsspecStore(fs=filesystem, path=store_path, read_only=True)
    group = zarr.open_group(store, mode="r")
    _zarr_cache[canonical_uri] = group

    if len(_zarr_cache) > 50:
        # Dicts iterate in insertion order, so the first key is the oldest.
        evicted = next(iter(_zarr_cache))
        _zarr_cache.pop(evicted)
        _scale_cache.pop(evicted, None)
    return group
def _get_scale_offset(canonical_uri: str):
    """Return the cached per-channel ``(scale, offset)`` float32 arrays.

    They are derived from the ``physical_min/max`` and ``digital_min/max``
    arrays in the recording's "channels" group so that
    ``physical = digital * scale + offset``.
    """
    try:
        return _scale_cache[canonical_uri]
    except KeyError:
        pass

    channels = _open_zarr(canonical_uri)["channels"]

    def _as_float64(field: str) -> np.ndarray:
        return np.array(channels[field][:], dtype=np.float64)

    physical_lo = _as_float64("physical_min")
    physical_hi = _as_float64("physical_max")
    digital_lo = _as_float64("digital_min")
    digital_hi = _as_float64("digital_max")

    # The tiny epsilon keeps the division finite when digital_max == digital_min.
    gain = (physical_hi - physical_lo) / (digital_hi - digital_lo + 1e-12)
    bias = physical_lo - digital_lo * gain

    pair = (gain.astype(np.float32), bias.astype(np.float32))
    _scale_cache[canonical_uri] = pair
    return pair
def read_signal_window(
    canonical_uri: str,
    start_sample: int,
    end_sample: int,
    channel_indices: Optional[list[int]] = None,
) -> np.ndarray:
    """Read a window of samples, convert it to physical units and smooth it.

    Args:
        canonical_uri: Recording whose "signal" array is read.
        start_sample: First sample index of the window (inclusive).
        end_sample: Sample index at which the window stops (exclusive).
        channel_indices: Rows to read; all rows when ``None``.

    Returns:
        A float32 array with one row per selected channel. Each row is
        scaled with the per-channel scale/offset and, when the window is
        longer than five samples, smoothed with a 5-point moving average.
    """
    signal = _open_zarr(canonical_uri)["signal"]
    rows = slice(None) if channel_indices is None else channel_indices
    digital = np.array(signal[rows, start_sample:end_sample], dtype=np.float32)

    gain, bias = _get_scale_offset(canonical_uri)
    if channel_indices is not None:
        gain, bias = gain[channel_indices], bias[channel_indices]
    physical = digital * gain[:, None] + bias[:, None]

    window = 5
    if physical.shape[1] > window:
        box = np.ones(window) / window
        # Each row is a view, so assigning in place writes back into `physical`.
        for channel in physical:
            channel[:] = np.convolve(channel, box, mode="same")
    return physical
def get_channel_names(canonical_uri: str) -> list[str]:
    """Return the recording's channel names as a list, read from ``channels/name``."""
    channels = _open_zarr(canonical_uri)["channels"]
    names = channels["name"][:]
    return list(names)
def get_store_metadata(canonical_uri: str) -> dict:
    """Collect descriptive metadata for a recording from its Zarr store.

    Channel and sample counts come from the shape of the "signal" array,
    channel names from ``channels/name``, and the remaining fields from the
    root group's attributes (``None`` when an attribute is absent).
    """
    root = _open_zarr(canonical_uri)
    attrs = dict(root.attrs)
    signal_shape = root["signal"].shape

    metadata = {
        "n_channels": signal_shape[0],
        "n_samples": signal_shape[1],
    }
    for field in ("sampling_rate_hz", "duration_s"):
        metadata[field] = attrs.get(field)
    metadata["channel_names"] = list(root["channels"]["name"][:])
    for field in (
        "reference",
        "montage_name",
        "recording_type",
        "manufacturer",
        "source_uri",
    ):
        metadata[field] = attrs.get(field)
    return metadata
# --- TUAR artifact annotations from CSV companion files ---
# Parsed annotation lists keyed by canonical_uri. An empty list is also stored
# when a recording's CSV could not be located or fetched.
_annotation_cache: dict[str, list[dict]] = {}
def _get_csv_path_from_source_uri(source_uri: str) -> Optional[str]:
    """Derive the companion CSV path ("bucket/key.csv") for an EDF source URI.

    Args:
        source_uri: URI of the archival recording, e.g. "s3://bucket/a/rec.edf".

    Returns:
        The URI with the "s3://" scheme removed and the ".edf" suffix replaced
        by ".csv", or ``None`` when ``source_uri`` is empty, is not a string
        (for example a pandas NaN standing in for a missing value) or does
        not end in ".edf".
    """
    # Manifest cells may hold None/NaN instead of a string; treat those as
    # "no source" rather than failing on the string methods below.
    if not isinstance(source_uri, str) or not source_uri.endswith(".edf"):
        return None
    without_scheme = source_uri.replace("s3://", "")
    return without_scheme.rsplit(".", 1)[0] + ".csv"
def get_annotations(canonical_uri: str, source_uri: Optional[str] = None) -> list[dict]:
    """Return the artifact annotations of a recording, fetching them on first use.

    Args:
        canonical_uri: Recording identifier; also the cache key.
        source_uri: URI of the original EDF file. When ``None`` it is looked
            up from the recording's store metadata.

    Returns:
        The annotation dicts produced by ``_parse_tuar_csv``, sorted by onset.
        An empty list is returned (and cached) when the source URI is unknown,
        is not an ``.edf`` file, or the CSV cannot be fetched or parsed.
    """
    cache_key = canonical_uri
    if cache_key in _annotation_cache:
        return _annotation_cache[cache_key]

    annotations: list[dict] = []
    try:
        if source_uri is None:
            # Fall back to the "source_uri" attribute recorded in the store.
            source_uri = get_store_metadata(canonical_uri).get("source_uri")
        csv_path = _get_csv_path_from_source_uri(source_uri)
        if csv_path is not None:
            bucket, key = csv_path.split("/", 1)
            resp = _get_boto_client().get_object(Bucket=bucket, Key=key)
            raw = resp["Body"].read().decode("utf-8")
            annotations = _parse_tuar_csv(raw)
    except Exception:
        # Deliberately best-effort: a missing or unreadable CSV means "no
        # annotations" for the viewer rather than an error.
        annotations = []

    _annotation_cache[cache_key] = annotations
    return annotations
def preload_all_annotations(df: pd.DataFrame) -> None:
    """Warm the annotation cache for every recording in ``df`` not cached yet.

    The CSV path is derived from each row's ``archival_uri``. Every distinct
    CSV is downloaded once, with one S3 GetObject request per file, and its
    parsed annotations are stored for each ``canonical_uri`` that refers to
    it. Files that cannot be fetched or parsed are cached as an empty list.

    Args:
        df: Manifest rows providing ``canonical_uri`` and ``archival_uri``.
    """
    # csv_path -> canonical URIs whose annotations live in that file.
    pending: dict[str, list[str]] = {}
    for _, row in df.iterrows():
        canonical_uri = row.get("canonical_uri", "")
        archival_uri = row.get("archival_uri", "")
        if not canonical_uri or canonical_uri in _annotation_cache:
            continue
        # Missing manifest values may surface as None/NaN rather than "".
        if not isinstance(archival_uri, str):
            continue
        csv_path = _get_csv_path_from_source_uri(archival_uri)
        if not csv_path:
            continue
        waiting = pending.setdefault(csv_path, [])
        if canonical_uri not in waiting:
            waiting.append(canonical_uri)
    if not pending:
        return

    client = _get_boto_client()
    for csv_path, canonical_uris in pending.items():
        try:
            bucket, key = csv_path.split("/", 1)
            resp = client.get_object(Bucket=bucket, Key=key)
            raw = resp["Body"].read().decode("utf-8")
            parsed = _parse_tuar_csv(raw)
        except Exception:
            # Best-effort preload: remember the failure as "no annotations".
            parsed = []
        for canonical_uri in canonical_uris:
            _annotation_cache[canonical_uri] = list(parsed)
def _parse_tuar_csv(raw: str) -> list[dict]:
    """Parse the text of an annotation CSV into a list of annotation dicts.

    Blank lines and lines starting with "#" are ignored, as is a first line
    that names both "channel" and "start_time" (the header). Data rows are
    read as ``channel,start,stop,label[,confidence]``.

    Rows with fewer than four fields or a non-numeric start/stop are skipped.
    A missing or non-numeric confidence defaults to 1.0. Rows repeating the
    same (start, stop, label) triple, with times rounded to milliseconds,
    are collapsed to the first occurrence regardless of channel.

    Args:
        raw: Decoded CSV text.

    Returns:
        Annotation dicts with keys ``onset_s``, ``duration_s``, ``end_s``,
        ``raw_label``, ``label``, ``color``, ``channel`` and ``confidence``,
        sorted by ``onset_s``.
    """
    lines = []
    for line in raw.strip().split("\n"):
        stripped = line.strip()
        if stripped and not stripped.startswith("#"):
            lines.append(stripped)
    if not lines:
        return []

    header_line = lines[0]
    if "channel" in header_line and "start_time" in header_line:
        lines = lines[1:]

    seen = set()
    annotations = []
    for line in lines:
        parts = [part.strip() for part in line.split(",")]
        if len(parts) < 4:
            continue
        channel = parts[0]
        try:
            start = float(parts[1])
            stop = float(parts[2])
        except ValueError:
            continue
        raw_label = parts[3].lower()

        # Confidence is optional: one bad value must not discard the row or
        # abort parsing of the whole file.
        confidence = 1.0
        if len(parts) > 4:
            try:
                confidence = float(parts[4])
            except ValueError:
                confidence = 1.0

        dedup_key = (round(start, 3), round(stop, 3), raw_label)
        if dedup_key in seen:
            continue
        seen.add(dedup_key)

        label = ARTIFACT_LABEL_MAP.get(raw_label, raw_label.title())
        color = ARTIFACT_COLORS.get(label, "rgba(128, 128, 128, 0.2)")
        annotations.append({
            "onset_s": start,
            "duration_s": stop - start,
            "end_s": stop,
            "raw_label": raw_label,
            "label": label,
            "color": color,
            "channel": channel,
            "confidence": confidence,
        })

    annotations.sort(key=lambda a: a["onset_s"])
    return annotations
def get_annotations_in_window(
    canonical_uri: str,
    start_s: float,
    end_s: float,
    source_uri: Optional[str] = None,
) -> list[dict]:
    """Return the recording's annotations that overlap ``(start_s, end_s)``.

    An annotation is kept when it ends after ``start_s`` and begins before
    ``end_s``; annotations that merely touch a window edge are excluded.
    """

    def _overlaps(annotation: dict) -> bool:
        return annotation["end_s"] > start_s and annotation["onset_s"] < end_s

    return list(filter(_overlaps, get_annotations(canonical_uri, source_uri)))
def get_recording_info(row: pd.Series) -> dict:
    """Flatten one manifest row into a plain dict for display.

    ``recording_id`` and ``dataset_id`` are required; every other field falls
    back to a default ("N/A", 0 or "") when the column is absent from the row.

    ``channel_names`` is always returned as a list. A string value is decoded
    as JSON, any other non-list iterable is converted with ``list()``, and a
    value that is missing, not valid JSON or not iterable (``None``, NaN)
    yields an empty list.

    Args:
        row: A row of the recordings manifest.

    Returns:
        A dict of display fields for the recording.

    Raises:
        KeyError: If ``recording_id`` or ``dataset_id`` is missing.
    """
    channel_names = row.get("channel_names", [])
    if isinstance(channel_names, str):
        try:
            channel_names = json.loads(channel_names)
        except (json.JSONDecodeError, TypeError):
            channel_names = []
    if not isinstance(channel_names, list):
        try:
            channel_names = list(channel_names)
        except TypeError:
            # None / NaN / other scalars: the manifest has no channel list.
            channel_names = []
    return {
        "recording_id": row["recording_id"],
        "dataset_id": row["dataset_id"],
        "subject": row.get("subject_id_in_dataset", "N/A"),
        "session": row.get("session_id", "N/A"),
        "task": row.get("task", "N/A"),
        "duration_s": row.get("duration_s", 0),
        "n_channels": row.get("n_channels", 0),
        "n_eeg_channels": row.get("n_eeg_channels", 0),
        "sampling_rate_hz": row.get("sampling_rate_hz", 0),
        "reference": row.get("reference", "N/A"),
        "montage_name": row.get("montage_name", "N/A"),
        "recording_type": row.get("recording_type", "N/A"),
        "archival_format": row.get("archival_format", "N/A"),
        "canonical_uri": row.get("canonical_uri", ""),
        "archival_uri": row.get("archival_uri", ""),
        "channel_names": channel_names,
        "roundtrip_class": row.get("roundtrip_class", "N/A"),
    }