Spaces:
Sleeping
Sleeping
| import io | |
| import json | |
| import os | |
| import threading | |
| import warnings | |
| from functools import lru_cache | |
| from typing import Optional | |
| import boto3 | |
| import numpy as np | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
# Load environment variables (e.g. the AWS_* settings read via os.environ
# below) from a .env file located next to this module.
load_dotenv(dotenv_path=os.path.join(os.path.dirname(__file__), ".env"))
# Silence UserWarnings whose message mentions "asynchronous".
warnings.filterwarnings(action="ignore", message=".*asynchronous.*", category=UserWarning)
# Hugging Face dataset that holds the EEG corpus manifest.
MANIFEST_DATASET = "tankalapavankalyan/eeg-corpus-manifest"
# Base URL of that dataset's "refs/convert/parquet" revision (URL-encoded ref).
PARQUET_BASE = (
    f"https://huggingface.co/datasets/{MANIFEST_DATASET}"
    "/resolve/refs%2Fconvert%2Fparquet"
)
# dataset_id values selected by get_tuar_recordings.
TUAR_DATASET_IDS = ["tuh_eeg_artifact"]
# Raw TUAR CSV label codes (as lower-cased by _parse_tuar_csv) -> display names.
# Single-artifact codes are spelled out; each compound code "<a>_<b>" is named
# "<name of a> + <name of b>" and is derived from the entries above it.
ARTIFACT_LABEL_MAP = {
    "eyem": "Eye Movement",
    "musc": "Muscle",
    "elec": "Electrode Pop",
    "chew": "Chewing",
    "shiv": "Shiver",
    "null": "Clean",
    "elpp": "Electrode Pop",
    "artf": "Artifact (Generic)",
    "bckg": "Background",
}
ARTIFACT_LABEL_MAP.update(
    (f"{first}_{second}", f"{ARTIFACT_LABEL_MAP[first]} + {ARTIFACT_LABEL_MAP[second]}")
    for first, second in (
        ("eyem", "musc"),
        ("musc", "elec"),
        ("eyem", "elec"),
        ("eyem", "chew"),
        ("chew", "elec"),
        ("chew", "musc"),
        ("eyem", "shiv"),
        ("shiv", "elec"),
    )
)
# Overlay fill colors as "rgba(r, g, b, alpha)" strings, keyed by the display
# names that ARTIFACT_LABEL_MAP produces. Labels missing from this table get a
# grey "rgba(128, 128, 128, 0.2)" fallback in _parse_tuar_csv.
ARTIFACT_COLORS = {
    # Single artifact types (alpha 0.25).
    "Eye Movement": "rgba(30, 144, 255, 0.25)",
    "Muscle": "rgba(220, 20, 60, 0.25)",
    "Electrode Pop": "rgba(255, 165, 0, 0.25)",
    "Chewing": "rgba(50, 205, 50, 0.25)",
    "Shiver": "rgba(148, 103, 189, 0.25)",
    "Artifact (Generic)": "rgba(128, 128, 128, 0.25)",
    # Non-artifact labels: near-transparent (alpha 0.08), identical colors.
    "Background": "rgba(200, 200, 200, 0.08)",
    "Clean": "rgba(200, 200, 200, 0.08)",
    # Compound labels (the "<a>_<b>" codes in ARTIFACT_LABEL_MAP). The RGB values
    # are hand-picked; several are close to the average of the two components.
    "Eye Movement + Muscle": "rgba(125, 82, 158, 0.25)",
    "Muscle + Electrode Pop": "rgba(238, 93, 30, 0.25)",
    "Eye Movement + Electrode Pop": "rgba(143, 155, 128, 0.25)",
    "Eye Movement + Chewing": "rgba(40, 175, 153, 0.25)",
    "Chewing + Electrode Pop": "rgba(153, 185, 30, 0.25)",
    "Chewing + Muscle": "rgba(135, 113, 56, 0.25)",
    "Eye Movement + Shiver": "rgba(89, 124, 222, 0.25)",
    "Shiver + Electrode Pop": "rgba(202, 134, 95, 0.25)",
}
| _boto_client = None | |
| _thread_local = threading.local() | |
def _get_boto_client():
    """Return the shared boto3 S3 client, creating it on first call.

    When both AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set they are
    passed explicitly; otherwise the client is built without explicit
    credentials. The region comes from AWS_DEFAULT_REGION ("us-east-1" if unset).
    """
    global _boto_client
    if _boto_client is not None:
        return _boto_client
    access_key = os.environ.get("AWS_ACCESS_KEY_ID")
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
    credentials = (
        {"aws_access_key_id": access_key, "aws_secret_access_key": secret_key}
        if access_key and secret_key
        else {}
    )
    _boto_client = boto3.client("s3", region_name=region, **credentials)
    return _boto_client
def _get_thread_s3fs():
    """Return the calling thread's s3fs filesystem, creating it on first use.

    One S3FileSystem is kept per thread (in ``_thread_local``) to avoid async
    event-loop conflicts. Explicit AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY are
    used when both are set; otherwise the filesystem is anonymous. The region
    comes from AWS_DEFAULT_REGION ("us-east-1" if unset).
    """
    import s3fs

    existing = getattr(_thread_local, "fs", None)
    if existing is not None:
        return existing
    access_key = os.environ.get("AWS_ACCESS_KEY_ID")
    secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
    client_kwargs = {"region_name": os.environ.get("AWS_DEFAULT_REGION", "us-east-1")}
    if access_key and secret_key:
        auth = {"key": access_key, "secret": secret_key}
    else:
        auth = {"anon": True}
    # Keyword order (auth, client_kwargs, asynchronous) matches the direct call form.
    fs = s3fs.S3FileSystem(**auth, client_kwargs=client_kwargs, asynchronous=False)
    _thread_local.fs = fs
    return fs
def reset_s3fs():
    """Drop every cached S3 handle and all data derived from them.

    Clears the shared boto3 client and the per-thread s3fs filesystems (for all
    threads), plus the zarr-group, scale/offset and annotation caches. The next
    access recreates everything from the current environment variables.
    """
    global _boto_client, _thread_local
    _boto_client = None
    # A threading.local's __dict__ only exposes the *calling* thread's slots, so
    # clearing it would leave other threads' filesystems cached. Rebinding the
    # module global to a fresh local() invalidates them in every thread, because
    # _get_thread_s3fs looks the global up on each call.
    _thread_local = threading.local()
    _zarr_cache.clear()
    _scale_cache.clear()
    _annotation_cache.clear()
# Columns requested from the manifest Parquet file by get_tuar_recordings.
MANIFEST_COLUMNS = [
    "recording_id",
    "dataset_id",
    "subject_id_in_dataset",
    "session_id",
    "run_id",
    "task",
    "archival_uri",
    "archival_format",
    "duration_s",
    "n_channels",
    "n_eeg_channels",
    "sampling_rate_hz",
    "reference",
    "montage_name",
    "recording_type",
    "channel_names",
    "canonical_uri",
    "conversion_status",
    "roundtrip_class",
]
def get_tuar_recordings() -> pd.DataFrame:
    """Return the successfully converted TUAR recordings from the manifest.

    Reads MANIFEST_COLUMNS from ``recordings/train/0000.parquet`` under
    PARQUET_BASE. Keeps only rows whose dataset_id is in TUAR_DATASET_IDS and
    whose conversion_status is "ok". The result is sorted by
    subject_id_in_dataset and re-indexed from 0.
    """
    manifest = pd.read_parquet(
        f"{PARQUET_BASE}/recordings/train/0000.parquet", columns=MANIFEST_COLUMNS
    )
    is_tuar = manifest["dataset_id"].isin(TUAR_DATASET_IDS)
    is_converted = manifest["conversion_status"] == "ok"
    return (
        manifest.loc[is_tuar & is_converted]
        .sort_values("subject_id_in_dataset")
        .reset_index(drop=True)
    )
def get_recording_display_list(df: pd.DataFrame) -> list[str]:
    """Build one " | "-separated summary label per row of ``df``.

    Each label holds the first 8 characters of recording_id (followed by "…"),
    dataset_id, subject, session ("N/A" when the session_id column is absent),
    duration in whole seconds, channel count, and sampling rate.
    """
    def describe(row: pd.Series) -> str:
        fields = [
            f"{row['recording_id'][:8]}…",
            f"{row['dataset_id']}",
            f"subj={row['subject_id_in_dataset']}",
            f"ses={row.get('session_id', 'N/A')}",
            f"dur={row['duration_s']:.0f}s",
            f"{row['n_channels']:.0f}ch @ {row['sampling_rate_hz']:.0f}Hz",
        ]
        return " | ".join(fields)

    return [describe(row) for _, row in df.iterrows()]
| # --- Signal access via direct Zarr + S3 --- | |
| _zarr_cache: dict[str, object] = {} | |
| _scale_cache: dict[str, tuple[np.ndarray, np.ndarray]] = {} | |
def _register_flac_codec() -> None:
    """Register the FLAC codec with zarr v3 if not already done.

    Probes zarr's codec registry for "numcodecs.flac". If that lookup raises
    KeyError, it registers ``flac_numcodecs.Flac`` with numcodecs, then registers
    a zarr codec class that wraps it under the same name. ``_open_zarr`` calls
    this before opening each store that is not already cached.
    """
    try:
        from zarr.registry import get_codec_class
        # Expected to raise KeyError while "numcodecs.flac" is not yet registered.
        get_codec_class("numcodecs.flac")
    except KeyError:
        import numcodecs
        from flac_numcodecs import Flac
        numcodecs.register_codec(Flac)
        # NOTE(review): imports a private zarr class (_NumcodecsArrayBytesCodec);
        # this may break across zarr releases — confirm the pinned zarr version.
        from zarr.codecs.numcodecs._codecs import _NumcodecsArrayBytesCodec
        from zarr.codecs import register_codec
        # Presumably exposes numcodecs' FLAC as a zarr v3 array->bytes codec
        # (inferred from the base-class name) — TODO confirm.
        class FlacCodec(_NumcodecsArrayBytesCodec):
            codec_name = "numcodecs.flac"
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
        register_codec("numcodecs.flac", FlacCodec)
def _open_zarr(canonical_uri: str):
    """Return the read-only zarr root group for a recording, opening it if needed.

    Args:
        canonical_uri: ``s3://bucket/path`` of the recording's zarr store.

    Returns:
        The group opened with mode "r" over an FsspecStore that uses the calling
        thread's s3fs filesystem.

    At most 50 groups are kept in ``_zarr_cache``. Entries are evicted
    least-recently-used first, and evicting a store also drops its
    ``_scale_cache`` entry.
    """
    max_cached = 50
    cached = _zarr_cache.pop(canonical_uri, None)
    if cached is not None:
        # Re-insert so dict order tracks recency: frequently used stores move to
        # the back and are not the ones evicted below.
        _zarr_cache[canonical_uri] = cached
        return cached
    _register_flac_codec()
    fs = _get_thread_s3fs()
    # Strip only a leading scheme; str.replace would rewrite any later "s3://" too.
    s3_path = canonical_uri.removeprefix("s3://")
    if not s3_path.endswith("/"):
        s3_path += "/"
    import zarr

    fsspec_store = zarr.storage.FsspecStore(fs=fs, path=s3_path, read_only=True)
    root = zarr.open_group(fsspec_store, mode="r")
    _zarr_cache[canonical_uri] = root
    while len(_zarr_cache) > max_cached:
        oldest = next(iter(_zarr_cache))
        _zarr_cache.pop(oldest, None)
        _scale_cache.pop(oldest, None)
    return root
def _get_scale_offset(canonical_uri: str):
    """Return per-channel float32 ``(scale, offset)`` for digital->physical conversion.

    Conversion is ``physical = digital * scale + offset``, where
    ``scale = (physical_max - physical_min) / (digital_max - digital_min)`` and
    ``offset = physical_min - digital_min * scale``. The inputs come from the
    store's ``channels`` group. Results are cached per canonical_uri.

    Channels whose digital range is zero (an invalid header) get identity
    scaling (scale 1, offset 0) instead of a division blow-up.
    """
    cached = _scale_cache.get(canonical_uri)
    if cached is not None:
        return cached
    ch_grp = _open_zarr(canonical_uri)["channels"]
    phys_min = np.array(ch_grp["physical_min"][:], dtype=np.float64)
    phys_max = np.array(ch_grp["physical_max"][:], dtype=np.float64)
    dig_min = np.array(ch_grp["digital_min"][:], dtype=np.float64)
    dig_max = np.array(ch_grp["digital_max"][:], dtype=np.float64)
    dig_range = dig_max - dig_min
    valid = dig_range != 0
    # Divide only where the digital range is non-zero; elsewhere keep scale 1.
    # A tiny epsilon in the denominator would instead produce a ~1e12 gain.
    scale = np.divide(
        phys_max - phys_min, dig_range, out=np.ones_like(dig_range), where=valid
    )
    offset = np.where(valid, phys_min - dig_min * scale, 0.0)
    result = (scale.astype(np.float32), offset.astype(np.float32))
    _scale_cache[canonical_uri] = result
    return result
def _moving_average(data: np.ndarray, kernel_size: int) -> np.ndarray:
    """Centered moving average of each row of a 2-D array, renormalized at the edges."""
    kernel = np.ones(kernel_size)
    # np.convolve(mode="same") implicitly zero-pads. Dividing by the number of
    # in-bounds samples under the kernel keeps edge samples from sagging toward 0.
    counts = np.convolve(np.ones(data.shape[1]), kernel, mode="same")
    smoothed = np.empty_like(data)
    for i, row in enumerate(data):
        smoothed[i] = np.convolve(row, kernel, mode="same") / counts
    return smoothed


def read_signal_window(
    canonical_uri: str,
    start_sample: int,
    end_sample: int,
    channel_indices: Optional[list[int]] = None,
    smooth_kernel: int = 5,
) -> np.ndarray:
    """Read samples ``[start_sample, end_sample)`` in physical units.

    Args:
        canonical_uri: S3 URI of the recording's zarr store.
        start_sample: First sample index (inclusive) along the signal's 2nd axis.
        end_sample: Last sample index (exclusive).
        channel_indices: Rows of ``signal`` to read. None reads all channels.
        smooth_kernel: Width of a centered moving-average filter applied to each
            channel. It is only applied when > 1 and the window is longer than
            the kernel. The default of 5 keeps the historical smoothing; pass 1
            (or 0) to get unsmoothed data.

    Returns:
        float32 array of shape (n_selected_channels, n_window_samples), scaled
        with the per-channel (scale, offset) from ``_get_scale_offset``.
    """
    root = _open_zarr(canonical_uri)
    sig_arr = root["signal"]
    rows = slice(None) if channel_indices is None else channel_indices
    raw = np.array(sig_arr[rows, start_sample:end_sample], dtype=np.float32)
    scale, offset = _get_scale_offset(canonical_uri)
    if channel_indices is not None:
        scale = scale[channel_indices]
        offset = offset[channel_indices]
    data = raw * scale[:, None] + offset[:, None]
    if smooth_kernel > 1 and data.shape[1] > smooth_kernel:
        data = _moving_average(data, smooth_kernel)
    return data
def get_channel_names(canonical_uri: str) -> list[str]:
    """Return the entries of the store's ``channels/name`` array as a list."""
    name_array = _open_zarr(canonical_uri)["channels"]["name"]
    return [name for name in name_array[:]]
def get_store_metadata(canonical_uri: str) -> dict:
    """Summarize a recording's zarr store.

    Returns a dict containing:
    - the first two dimensions of ``signal`` as n_channels and n_samples;
    - the ``channels/name`` entries;
    - selected root attributes, each None when the attribute is absent.
    """
    root = _open_zarr(canonical_uri)
    attrs = dict(root.attrs)
    signal_shape = root["signal"].shape
    summary = {"n_channels": signal_shape[0], "n_samples": signal_shape[1]}
    summary.update({name: attrs.get(name) for name in ("sampling_rate_hz", "duration_s")})
    summary["channel_names"] = list(root["channels"]["name"][:])
    summary.update(
        {
            name: attrs.get(name)
            for name in (
                "reference",
                "montage_name",
                "recording_type",
                "manufacturer",
                "source_uri",
            )
        }
    )
    return summary
| # --- TUAR artifact annotations from CSV companion files --- | |
| _annotation_cache: dict[str, list[dict]] = {} | |
| def _get_csv_path_from_source_uri(source_uri: str) -> Optional[str]: | |
| if not source_uri or not source_uri.endswith(".edf"): | |
| return None | |
| return source_uri.replace("s3://", "").rsplit(".", 1)[0] + ".csv" | |
def get_annotations(canonical_uri: str, source_uri: Optional[str] = None) -> list[dict]:
    """Return the parsed TUAR artifact annotations for one recording.

    Args:
        canonical_uri: S3 URI of the recording's zarr store; also the cache key.
        source_uri: URI of the original ``.edf`` file. When omitted, it is read
            from the store's ``source_uri`` root attribute.

    Returns:
        Annotation dicts as produced by ``_parse_tuar_csv``. The result is []
        when the source URI is unavailable or not an ``.edf``, or when the CSV
        cannot be fetched or parsed. Every outcome, including [], is cached in
        ``_annotation_cache``.
    """
    cached = _annotation_cache.get(canonical_uri)
    if cached is not None:
        return cached
    if source_uri is None:
        try:
            # Same root attribute that get_store_metadata reports as "source_uri".
            source_uri = dict(_open_zarr(canonical_uri).attrs).get("source_uri")
        except Exception:
            _annotation_cache[canonical_uri] = []
            return []
    csv_path = _get_csv_path_from_source_uri(source_uri)
    if csv_path is None:
        _annotation_cache[canonical_uri] = []
        return []
    try:
        bucket, key = csv_path.split("/", 1)
        resp = _get_boto_client().get_object(Bucket=bucket, Key=key)
        raw = resp["Body"].read().decode("utf-8")
        # Parse inside the try so a malformed file degrades to [] instead of
        # raising, mirroring preload_all_annotations.
        annotations = _parse_tuar_csv(raw)
    except Exception:
        annotations = []
    _annotation_cache[canonical_uri] = annotations
    return annotations
def preload_all_annotations(df: pd.DataFrame) -> None:
    """Fetch and cache artifact annotations for every recording in ``df``.

    For each row, the companion CSV path is derived from ``archival_uri`` and
    the result is stored in ``_annotation_cache`` under ``canonical_uri``.
    CSVs are fetched sequentially with one S3 GetObject per distinct CSV path;
    this is not a single batch request. Rows are skipped when canonical_uri is
    missing or not a string, when it is already cached, or when archival_uri is
    not an ``.edf`` string. Fetch or parse failures are cached as [].
    """
    pending: dict[str, str] = {}  # canonical_uri -> csv_path
    for _, row in df.iterrows():
        canonical_uri = row.get("canonical_uri", "")
        archival_uri = row.get("archival_uri", "")
        if not isinstance(canonical_uri, str) or not canonical_uri:
            continue
        if canonical_uri in _annotation_cache or not isinstance(archival_uri, str):
            continue
        csv_path = _get_csv_path_from_source_uri(archival_uri)
        if csv_path:
            # Keyed by recording so recordings that share a CSV are all cached.
            pending[canonical_uri] = csv_path
    if not pending:
        return
    client = _get_boto_client()
    parsed_by_path: dict[str, list[dict]] = {}
    for canonical_uri, csv_path in pending.items():
        if csv_path not in parsed_by_path:
            try:
                bucket, key = csv_path.split("/", 1)
                resp = client.get_object(Bucket=bucket, Key=key)
                raw = resp["Body"].read().decode("utf-8")
                parsed_by_path[csv_path] = _parse_tuar_csv(raw)
            except Exception:
                parsed_by_path[csv_path] = []
        _annotation_cache[canonical_uri] = list(parsed_by_path[csv_path])
def _parse_tuar_csv(raw: str) -> list[dict]:
    """Parse the text of a TUAR annotation CSV into annotation dicts.

    Blank lines and lines starting with "#" are ignored. A first remaining line
    that contains "channel" and "start_time" is treated as a header and
    dropped. Each other row is read as ``channel,start,stop,label[,confidence]``:
    - rows with fewer than four fields, or a non-numeric start/stop, are skipped;
    - a missing or non-numeric confidence defaults to 1.0 rather than aborting
      the whole parse.

    Rows with the same (start, stop, label), rounded to 3 decimals, are
    collapsed to the first one seen. An event listed on several channels
    therefore yields one annotation, whose "channel" is the first channel.

    Returns:
        Dicts with keys onset_s, duration_s, end_s, raw_label, label, color,
        channel and confidence, sorted by onset_s. The label and color come from
        ARTIFACT_LABEL_MAP / ARTIFACT_COLORS, with title-case and grey fallbacks.
    """
    lines = [
        line
        for line in raw.strip().split("\n")
        if not line.startswith("#") and line.strip()
    ]
    if not lines:
        return []
    if "channel" in lines[0] and "start_time" in lines[0]:
        lines = lines[1:]
    seen: set[tuple[float, float, str]] = set()
    annotations: list[dict] = []
    for line in lines:
        parts = [field.strip() for field in line.strip().split(",")]
        if len(parts) < 4:
            continue
        channel = parts[0]
        try:
            start = float(parts[1])
            stop = float(parts[2])
        except ValueError:
            continue
        raw_label = parts[3].lower()
        confidence = 1.0
        if len(parts) > 4:
            try:
                confidence = float(parts[4])
            except ValueError:
                confidence = 1.0  # e.g. an empty trailing field
        dedup_key = (round(start, 3), round(stop, 3), raw_label)
        if dedup_key in seen:
            continue
        seen.add(dedup_key)
        label = ARTIFACT_LABEL_MAP.get(raw_label, raw_label.title())
        color = ARTIFACT_COLORS.get(label, "rgba(128, 128, 128, 0.2)")
        annotations.append({
            "onset_s": start,
            "duration_s": stop - start,
            "end_s": stop,
            "raw_label": raw_label,
            "label": label,
            "color": color,
            "channel": channel,
            "confidence": confidence,
        })
    annotations.sort(key=lambda ann: ann["onset_s"])
    return annotations
def get_annotations_in_window(
    canonical_uri: str,
    start_s: float,
    end_s: float,
    source_uri: Optional[str] = None,
) -> list[dict]:
    """Return the recording's annotations that overlap the window ``(start_s, end_s)``.

    An annotation is kept when it ends after ``start_s`` and starts before
    ``end_s``. Annotations that only touch a window edge are excluded.
    """
    def overlaps(annotation: dict) -> bool:
        return annotation["end_s"] > start_s and annotation["onset_s"] < end_s

    return list(filter(overlaps, get_annotations(canonical_uri, source_uri)))
def _normalize_channel_names(value) -> list:
    """Coerce a manifest ``channel_names`` cell into a plain list ([] when unusable)."""
    if isinstance(value, str):
        try:
            value = json.loads(value)
        except json.JSONDecodeError:
            return []
        if isinstance(value, str):
            # A JSON string scalar is a single name, not a sequence of characters.
            return [value]
    if value is None:
        return []
    if isinstance(value, list):
        return value
    try:
        return list(value)
    except TypeError:
        # Non-iterable scalars, e.g. the float NaN pandas uses for missing cells.
        return []


def get_recording_info(row: pd.Series) -> dict:
    """Flatten one manifest row into a display-friendly info dict.

    Args:
        row: A manifest row; ``recording_id`` and ``dataset_id`` are required.

    Returns:
        A dict of recording properties. Missing optional text fields become
        "N/A", missing numeric fields become 0, and missing URIs become "".
        ``channel_names`` is always a list; it may arrive as a JSON string, a
        list/array-like, or be missing/None/NaN (-> []).

    Raises:
        KeyError: if ``recording_id`` or ``dataset_id`` is absent.
    """
    return {
        "recording_id": row["recording_id"],
        "dataset_id": row["dataset_id"],
        "subject": row.get("subject_id_in_dataset", "N/A"),
        "session": row.get("session_id", "N/A"),
        "task": row.get("task", "N/A"),
        "duration_s": row.get("duration_s", 0),
        "n_channels": row.get("n_channels", 0),
        "n_eeg_channels": row.get("n_eeg_channels", 0),
        "sampling_rate_hz": row.get("sampling_rate_hz", 0),
        "reference": row.get("reference", "N/A"),
        "montage_name": row.get("montage_name", "N/A"),
        "recording_type": row.get("recording_type", "N/A"),
        "archival_format": row.get("archival_format", "N/A"),
        "canonical_uri": row.get("canonical_uri", ""),
        "archival_uri": row.get("archival_uri", ""),
        "channel_names": _normalize_channel_names(row.get("channel_names", [])),
        "roundtrip_class": row.get("roundtrip_class", "N/A"),
    }