Spaces:
Sleeping
Sleeping
Commit ·
b11ec91
1
Parent(s): 1442b78
[add] app files
Browse files- requirements.txt +7 -3
- src/loader.py +393 -0
- src/preprocess.py +59 -0
- src/streamlit_app.py +853 -37
requirements.txt
CHANGED
|
@@ -1,3 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
numpy
|
| 3 |
+
plotly
|
| 4 |
+
scipy
|
| 5 |
+
mne
|
| 6 |
+
h5py
|
| 7 |
+
networkx
|
src/loader.py
ADDED
|
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List, Tuple, Optional, Dict, Any
|
| 4 |
+
import io
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import mne
|
| 11 |
+
from scipy.io import loadmat
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import h5py # MAT v7.3 (HDF5)
|
| 15 |
+
except Exception: # pragma: no cover
|
| 16 |
+
h5py = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ============================================================
|
| 20 |
+
# EEGLAB loader (.set + .fdt)
|
| 21 |
+
# ============================================================
|
| 22 |
+
def pick_set_fdt(files) -> Tuple[Optional[object], Optional[object]]:
|
| 23 |
+
"""
|
| 24 |
+
Streamlitの accept_multiple_files=True で受け取ったfilesから .set と .fdt を拾う。
|
| 25 |
+
Returns: (set_file, fdt_file)
|
| 26 |
+
"""
|
| 27 |
+
set_file = None
|
| 28 |
+
fdt_file = None
|
| 29 |
+
for f in files:
|
| 30 |
+
name = (getattr(f, "name", "") or "").lower()
|
| 31 |
+
if name.endswith(".set"):
|
| 32 |
+
set_file = f
|
| 33 |
+
elif name.endswith(".fdt"):
|
| 34 |
+
fdt_file = f
|
| 35 |
+
return set_file, fdt_file
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def same_stem(a_name: str, b_name: str) -> bool:
|
| 39 |
+
"""Check if two filenames have the same stem (basename without extension)."""
|
| 40 |
+
a_stem = os.path.splitext(os.path.basename(a_name))[0]
|
| 41 |
+
b_stem = os.path.splitext(os.path.basename(b_name))[0]
|
| 42 |
+
return a_stem == b_stem
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _load_eeglab_hdf5(set_path: str, fdt_path: Optional[str] = None, debug: bool = False) -> Tuple[np.ndarray, float]:
|
| 46 |
+
"""
|
| 47 |
+
Load EEGLAB .set file saved in MATLAB v7.3 (HDF5) format using h5py.
|
| 48 |
+
Returns: (x_tc, fs) where x_tc is (T, C)
|
| 49 |
+
"""
|
| 50 |
+
if h5py is None:
|
| 51 |
+
raise RuntimeError("EEGLAB .set ファイルが MATLAB v7.3 (HDF5) 形式ですが、h5py がインストールされていません。pip install h5py を実行してください。")
|
| 52 |
+
|
| 53 |
+
with h5py.File(set_path, "r") as f:
|
| 54 |
+
# デバッグ: ファイル構造を表示
|
| 55 |
+
if debug:
|
| 56 |
+
print("=== HDF5 file structure ===")
|
| 57 |
+
def print_structure(name, obj):
|
| 58 |
+
if isinstance(obj, h5py.Dataset):
|
| 59 |
+
print(f"Dataset: {name}, shape: {obj.shape}, dtype: {obj.dtype}")
|
| 60 |
+
elif isinstance(obj, h5py.Group):
|
| 61 |
+
print(f"Group: {name}")
|
| 62 |
+
f.visititems(print_structure)
|
| 63 |
+
print("===========================")
|
| 64 |
+
|
| 65 |
+
# サンプリングレートを取得
|
| 66 |
+
fs = None
|
| 67 |
+
for path in ["EEG/srate", "srate"]:
|
| 68 |
+
if path in f:
|
| 69 |
+
srate_data = f[path]
|
| 70 |
+
if isinstance(srate_data, h5py.Dataset):
|
| 71 |
+
val = srate_data[()]
|
| 72 |
+
# 配列の場合は最初の要素を取得
|
| 73 |
+
fs = float(val.flat[0]) if hasattr(val, 'flat') else float(val)
|
| 74 |
+
break
|
| 75 |
+
|
| 76 |
+
if fs is None:
|
| 77 |
+
raise ValueError("サンプリングレート (srate) が見つかりません")
|
| 78 |
+
|
| 79 |
+
# チャンネル数を取得
|
| 80 |
+
nbchan = None
|
| 81 |
+
for path in ["EEG/nbchan", "nbchan"]:
|
| 82 |
+
if path in f:
|
| 83 |
+
nbchan_data = f[path]
|
| 84 |
+
if isinstance(nbchan_data, h5py.Dataset):
|
| 85 |
+
val = nbchan_data[()]
|
| 86 |
+
nbchan = int(val.flat[0]) if hasattr(val, 'flat') else int(val)
|
| 87 |
+
break
|
| 88 |
+
|
| 89 |
+
# サンプル数を取得
|
| 90 |
+
pnts = None
|
| 91 |
+
for path in ["EEG/pnts", "pnts"]:
|
| 92 |
+
if path in f:
|
| 93 |
+
pnts_data = f[path]
|
| 94 |
+
if isinstance(pnts_data, h5py.Dataset):
|
| 95 |
+
val = pnts_data[()]
|
| 96 |
+
pnts = int(val.flat[0]) if hasattr(val, 'flat') else int(val)
|
| 97 |
+
break
|
| 98 |
+
|
| 99 |
+
if debug:
|
| 100 |
+
print(f"nbchan: {nbchan}, pnts: {pnts}, fs: {fs}")
|
| 101 |
+
|
| 102 |
+
# データを取得 - まず .set 内を確認
|
| 103 |
+
data = None
|
| 104 |
+
data_shape = None
|
| 105 |
+
|
| 106 |
+
if debug:
|
| 107 |
+
print(f"Checking for data, fdt_path provided: {fdt_path is not None}")
|
| 108 |
+
if fdt_path:
|
| 109 |
+
print(f"fdt_path exists: {os.path.exists(fdt_path)}")
|
| 110 |
+
|
| 111 |
+
# パターン1: EEG/data が参照配列の場合、各参照を辿る
|
| 112 |
+
if "EEG" in f and "data" in f["EEG"]:
|
| 113 |
+
data_ref = f["EEG"]["data"]
|
| 114 |
+
if isinstance(data_ref, h5py.Dataset):
|
| 115 |
+
if debug:
|
| 116 |
+
print(f"EEG/data dtype: {data_ref.dtype}, shape: {data_ref.shape}, size: {data_ref.size}")
|
| 117 |
+
|
| 118 |
+
if data_ref.dtype == h5py.ref_dtype:
|
| 119 |
+
# 参照の場合 - 通常は .fdt ファイルを指す
|
| 120 |
+
if debug:
|
| 121 |
+
print("EEG/data is reference type - data should be in .fdt file")
|
| 122 |
+
# .fdt ファイルが必要
|
| 123 |
+
if fdt_path is not None and os.path.exists(fdt_path):
|
| 124 |
+
data = _load_fdt_file(fdt_path, nbchan, pnts, debug=debug)
|
| 125 |
+
else:
|
| 126 |
+
raise ValueError(".fdt ファイルが必要ですが見つかりません。.set と .fdt の両方をアップロードしてください。")
|
| 127 |
+
elif data_ref.size > 100: # 参照配列ではなく実データ
|
| 128 |
+
data = data_ref[()]
|
| 129 |
+
data_shape = data.shape
|
| 130 |
+
if debug:
|
| 131 |
+
print(f"EEG/data contains actual data, shape: {data_shape}")
|
| 132 |
+
else:
|
| 133 |
+
# 小さい配列 = 参照リスト、.fdtファイルが必要
|
| 134 |
+
if debug:
|
| 135 |
+
print(f"EEG/data is small array (size={data_ref.size}), assuming reference to .fdt")
|
| 136 |
+
if fdt_path is not None and os.path.exists(fdt_path):
|
| 137 |
+
data = _load_fdt_file(fdt_path, nbchan, pnts, debug=debug)
|
| 138 |
+
else:
|
| 139 |
+
raise ValueError(".fdt ファイルが必要ですが見つかりません。.set と .fdt の両方をアップロードしてください。")
|
| 140 |
+
|
| 141 |
+
# パターン2: 直接 data
|
| 142 |
+
if data is None and "data" in f:
|
| 143 |
+
data_obj = f["data"]
|
| 144 |
+
if isinstance(data_obj, h5py.Dataset):
|
| 145 |
+
data = data_obj[()]
|
| 146 |
+
data_shape = data.shape
|
| 147 |
+
|
| 148 |
+
if data is None:
|
| 149 |
+
raise ValueError("EEGデータが見つかりません。.fdt ファイルが必要な可能性があります。")
|
| 150 |
+
|
| 151 |
+
if debug:
|
| 152 |
+
print(f"Data shape: {data.shape if hasattr(data, 'shape') else 'loaded from fdt'}")
|
| 153 |
+
|
| 154 |
+
# データの形状を調整
|
| 155 |
+
if data.ndim != 2:
|
| 156 |
+
raise ValueError(f"予期しないデータ次元: {data.ndim}")
|
| 157 |
+
|
| 158 |
+
dim0, dim1 = data.shape
|
| 159 |
+
|
| 160 |
+
# nbchan情報があればそれを使う
|
| 161 |
+
if nbchan is not None:
|
| 162 |
+
if dim0 == nbchan:
|
| 163 |
+
# (C, T) 形式
|
| 164 |
+
x_tc = data.T.astype(np.float32)
|
| 165 |
+
elif dim1 == nbchan:
|
| 166 |
+
# (T, C) 形式
|
| 167 |
+
x_tc = data.astype(np.float32)
|
| 168 |
+
else:
|
| 169 |
+
# nbchanと一致しない場合は小さい方をチャンネル数と仮定
|
| 170 |
+
if dim0 < dim1:
|
| 171 |
+
x_tc = data.T.astype(np.float32)
|
| 172 |
+
else:
|
| 173 |
+
x_tc = data.astype(np.float32)
|
| 174 |
+
else:
|
| 175 |
+
# 一般的な判定: 小さい方がチャンネル数
|
| 176 |
+
if dim0 < dim1:
|
| 177 |
+
x_tc = data.T.astype(np.float32)
|
| 178 |
+
else:
|
| 179 |
+
x_tc = data.astype(np.float32)
|
| 180 |
+
|
| 181 |
+
if debug:
|
| 182 |
+
print(f"Final shape (T, C): {x_tc.shape}")
|
| 183 |
+
|
| 184 |
+
return x_tc, fs
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _load_fdt_file(fdt_path: str, nbchan: Optional[int], pnts: Optional[int], debug: bool = False) -> np.ndarray:
|
| 188 |
+
"""
|
| 189 |
+
Load .fdt file (raw binary float32 data).
|
| 190 |
+
EEGLAB .fdt files are stored as float32 in (C, T) order.
|
| 191 |
+
"""
|
| 192 |
+
if debug:
|
| 193 |
+
print(f"Loading .fdt file: {fdt_path}")
|
| 194 |
+
|
| 195 |
+
# .fdt ファイルは float32 のバイナリデータ
|
| 196 |
+
data = np.fromfile(fdt_path, dtype=np.float32)
|
| 197 |
+
|
| 198 |
+
if debug:
|
| 199 |
+
print(f"Loaded {data.size} float32 values from .fdt")
|
| 200 |
+
|
| 201 |
+
# チャンネル数とサンプル数がわかっている場合はリシェイプ
|
| 202 |
+
if nbchan is not None and pnts is not None:
|
| 203 |
+
expected_size = nbchan * pnts
|
| 204 |
+
if data.size == expected_size:
|
| 205 |
+
# EEGLAB は (C, T) 順で保存
|
| 206 |
+
data = data.reshape(nbchan, pnts)
|
| 207 |
+
if debug:
|
| 208 |
+
print(f"Reshaped to ({nbchan}, {pnts})")
|
| 209 |
+
else:
|
| 210 |
+
if debug:
|
| 211 |
+
print(f"Warning: expected {expected_size} values but got {data.size}")
|
| 212 |
+
# 可能な限りリシェイプを試みる
|
| 213 |
+
if data.size % nbchan == 0:
|
| 214 |
+
data = data.reshape(nbchan, -1)
|
| 215 |
+
elif data.size % pnts == 0:
|
| 216 |
+
data = data.reshape(-1, pnts)
|
| 217 |
+
else:
|
| 218 |
+
raise ValueError(f"Cannot reshape data of size {data.size} with nbchan={nbchan}, pnts={pnts}")
|
| 219 |
+
else:
|
| 220 |
+
raise ValueError("nbchan と pnts の情報が必要です")
|
| 221 |
+
|
| 222 |
+
return data
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def load_eeglab_tc_from_bytes(
|
| 226 |
+
set_bytes: bytes,
|
| 227 |
+
set_name: str,
|
| 228 |
+
fdt_bytes: Optional[bytes] = None,
|
| 229 |
+
fdt_name: Optional[str] = None,
|
| 230 |
+
) -> Tuple[np.ndarray, float]:
|
| 231 |
+
"""
|
| 232 |
+
Load EEGLAB .set (and optional .fdt) from bytes using MNE or h5py.
|
| 233 |
+
Returns:
|
| 234 |
+
x_tc: (T, C) float32
|
| 235 |
+
fs: sampling rate (Hz)
|
| 236 |
+
|
| 237 |
+
Notes:
|
| 238 |
+
- 多くのEEGLABは .set が .fdt を参照するため、同じディレクトリに同名で置く必要があります。
|
| 239 |
+
- .set単体で完結している場合は fdt_* を省略可能にしています。
|
| 240 |
+
- MATLAB v7.3 (HDF5) 形式の .set にも対応しています。
|
| 241 |
+
"""
|
| 242 |
+
if fdt_bytes is not None or fdt_name is not None:
|
| 243 |
+
if fdt_bytes is None or fdt_name is None:
|
| 244 |
+
raise ValueError("fdt_bytes と fdt_name は両方指定してください。")
|
| 245 |
+
if not same_stem(set_name, fdt_name):
|
| 246 |
+
raise ValueError(f".set と .fdt のファイル名(拡張子除く)が一致していません: {set_name} vs {fdt_name}")
|
| 247 |
+
|
| 248 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 249 |
+
set_path = os.path.join(tmpdir, os.path.basename(set_name))
|
| 250 |
+
with open(set_path, "wb") as f:
|
| 251 |
+
f.write(set_bytes)
|
| 252 |
+
|
| 253 |
+
fdt_path = None # 初期化
|
| 254 |
+
if fdt_bytes is not None and fdt_name is not None:
|
| 255 |
+
fdt_path = os.path.join(tmpdir, os.path.basename(fdt_name))
|
| 256 |
+
with open(fdt_path, "wb") as f:
|
| 257 |
+
f.write(fdt_bytes)
|
| 258 |
+
|
| 259 |
+
# 1) Rawとして読む(通常のEEGLAB形式)
|
| 260 |
+
try:
|
| 261 |
+
raw = mne.io.read_raw_eeglab(set_path, preload=True, verbose=False)
|
| 262 |
+
fs = float(raw.info["sfreq"])
|
| 263 |
+
x_tc = raw.get_data().T # (T,C)
|
| 264 |
+
return x_tc.astype(np.float32), fs
|
| 265 |
+
|
| 266 |
+
except Exception as e_raw:
|
| 267 |
+
# 2) Epochsとして読む(エポックデータ用)
|
| 268 |
+
try:
|
| 269 |
+
epochs = mne.io.read_epochs_eeglab(set_path, verbose=False, montage_units="cm")
|
| 270 |
+
fs = float(epochs.info["sfreq"])
|
| 271 |
+
x = epochs.get_data(copy=True) # (n_epochs, n_channels, n_times)
|
| 272 |
+
|
| 273 |
+
# ここは方針を選ぶ:平均 or 連結
|
| 274 |
+
x_mean = x.mean(axis=0) # (C,T)
|
| 275 |
+
x_tc = x_mean.T # (T,C)
|
| 276 |
+
return x_tc.astype(np.float32), fs
|
| 277 |
+
|
| 278 |
+
except Exception as e_ep:
|
| 279 |
+
# 3) HDF5形式として読む(MATLAB v7.3)
|
| 280 |
+
try:
|
| 281 |
+
# デバッグモードを有効化(環境変数で制御可能)
|
| 282 |
+
debug = os.environ.get("EEGLAB_DEBUG", "0") == "1"
|
| 283 |
+
# Streamlit環境では常にデバッグ情報を表示
|
| 284 |
+
import sys
|
| 285 |
+
if 'streamlit' in sys.modules:
|
| 286 |
+
debug = True
|
| 287 |
+
x_tc, fs = _load_eeglab_hdf5(set_path, fdt_path=fdt_path, debug=debug)
|
| 288 |
+
return x_tc, fs
|
| 289 |
+
|
| 290 |
+
except Exception as e_hdf5:
|
| 291 |
+
# すべて失敗した場合
|
| 292 |
+
msg = (
|
| 293 |
+
"EEGLABの読み込みに失敗しました。\n"
|
| 294 |
+
f"- read_raw_eeglab error: {e_raw}\n"
|
| 295 |
+
f"- read_epochs_eeglab error: {e_ep}\n"
|
| 296 |
+
f"- HDF5読み込み error: {e_hdf5}\n"
|
| 297 |
+
)
|
| 298 |
+
raise RuntimeError(msg) from e_hdf5
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# ============================================================
|
| 303 |
+
# MAT loader (.mat)
|
| 304 |
+
# ============================================================
|
| 305 |
+
def _mat_keys_loadmat(mat_dict: Dict[str, Any]) -> List[str]:
|
| 306 |
+
return sorted([k for k in mat_dict.keys() if not k.startswith("__")])
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _try_get_numeric_arrays_loadmat(mat_dict: Dict[str, Any]) -> Dict[str, np.ndarray]:
|
| 310 |
+
"""
|
| 311 |
+
loadmatで読んだdictから、1D/2Dの数値ndarrayだけ抽出して返す。
|
| 312 |
+
3次元配列も含める(エポックデータの可能性)。
|
| 313 |
+
"""
|
| 314 |
+
out: Dict[str, np.ndarray] = {}
|
| 315 |
+
for k in _mat_keys_loadmat(mat_dict):
|
| 316 |
+
v = mat_dict[k]
|
| 317 |
+
if isinstance(v, np.ndarray) and v.size > 0:
|
| 318 |
+
# 数値型かどうかチェック
|
| 319 |
+
if np.issubdtype(v.dtype, np.number):
|
| 320 |
+
if v.ndim in (1, 2):
|
| 321 |
+
out[k] = v
|
| 322 |
+
elif v.ndim == 3:
|
| 323 |
+
# 3次元配列の場合は (epochs, channels, time) の可能性
|
| 324 |
+
# 平均を取って2次元にする、または連結する
|
| 325 |
+
out[k + "_mean"] = v.mean(axis=0) # (C, T)
|
| 326 |
+
out[k + "_concat"] = v.reshape(-1, v.shape[-1]) # (epochs*C, T)
|
| 327 |
+
return out
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def _load_mat_v72(bytes_data: bytes) -> Dict[str, Any]:
|
| 331 |
+
# v7.2以前のMAT(一般的なMAT)
|
| 332 |
+
return loadmat(io.BytesIO(bytes_data), squeeze_me=False, struct_as_record=False)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _load_mat_v73_candidates(bytes_data: bytes) -> Dict[str, np.ndarray]:
|
| 336 |
+
"""
|
| 337 |
+
v7.3(HDF5)のMATから、数値1D/2D/3D dataset を拾って返す。
|
| 338 |
+
keyは HDF5内のパスになります(例: 'group/data')。
|
| 339 |
+
|
| 340 |
+
修正: h5pyの新しいバージョンに対応。BytesIOではなく一時ファイルを使用。
|
| 341 |
+
"""
|
| 342 |
+
if h5py is None:
|
| 343 |
+
raise RuntimeError("MAT v7.3(HDF5) 形式の可能性がありますが、h5py が入っていません。pip install h5py を実行してください。")
|
| 344 |
+
|
| 345 |
+
out: Dict[str, np.ndarray] = {}
|
| 346 |
+
|
| 347 |
+
# h5pyの新しいバージョンではBytesIOから直接開けない場合があるため、一時ファイルを使用
|
| 348 |
+
with tempfile.NamedTemporaryFile(suffix='.mat', delete=False) as tmp:
|
| 349 |
+
tmp.write(bytes_data)
|
| 350 |
+
tmp_path = tmp.name
|
| 351 |
+
|
| 352 |
+
try:
|
| 353 |
+
with h5py.File(tmp_path, "r") as f:
|
| 354 |
+
|
| 355 |
+
def visitor(name, obj):
|
| 356 |
+
if not isinstance(obj, h5py.Dataset):
|
| 357 |
+
return
|
| 358 |
+
try:
|
| 359 |
+
arr = obj[()]
|
| 360 |
+
except Exception:
|
| 361 |
+
return
|
| 362 |
+
|
| 363 |
+
# MATLABの文字列/参照等は除外して、数値だけ
|
| 364 |
+
if isinstance(arr, np.ndarray) and arr.size > 0 and np.issubdtype(arr.dtype, np.number):
|
| 365 |
+
if arr.ndim in (1, 2):
|
| 366 |
+
out[name] = arr
|
| 367 |
+
elif arr.ndim == 3:
|
| 368 |
+
# 3次元配列も含める
|
| 369 |
+
out[name + "_mean"] = arr.mean(axis=0)
|
| 370 |
+
out[name + "_concat"] = arr.reshape(-1, arr.shape[-1])
|
| 371 |
+
|
| 372 |
+
f.visititems(lambda name, obj: visitor(name, obj))
|
| 373 |
+
finally:
|
| 374 |
+
# 一時ファイルを削除
|
| 375 |
+
try:
|
| 376 |
+
os.unlink(tmp_path)
|
| 377 |
+
except Exception:
|
| 378 |
+
pass
|
| 379 |
+
|
| 380 |
+
return out
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def load_mat_candidates(bytes_data: bytes) -> Dict[str, np.ndarray]:
|
| 384 |
+
"""
|
| 385 |
+
Return dict: variable_name -> ndarray(1D/2D numeric)
|
| 386 |
+
Tries v7.2 (scipy.io.loadmat). If it fails, tries v7.3 (h5py).
|
| 387 |
+
"""
|
| 388 |
+
try:
|
| 389 |
+
md = _load_mat_v72(bytes_data)
|
| 390 |
+
cands = _try_get_numeric_arrays_loadmat(md)
|
| 391 |
+
return cands
|
| 392 |
+
except Exception:
|
| 393 |
+
return _load_mat_v73_candidates(bytes_data)
|
src/preprocess.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# preprocessing.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
import numpy as np
|
| 5 |
+
import mne
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass(frozen=True)
|
| 9 |
+
class PreprocessConfig:
|
| 10 |
+
fs: float
|
| 11 |
+
f_low: float
|
| 12 |
+
f_high: float
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def to_time_channel(x: np.ndarray) -> np.ndarray:
|
| 16 |
+
if x.ndim == 1:
|
| 17 |
+
return x[:, None]
|
| 18 |
+
if x.ndim != 2:
|
| 19 |
+
raise ValueError(f"Expected 1D or 2D array, got {x.shape}")
|
| 20 |
+
T, C = x.shape
|
| 21 |
+
if T <= 256 and C > T:
|
| 22 |
+
x = x.T
|
| 23 |
+
return x
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
|
| 27 |
+
info = mne.create_info(
|
| 28 |
+
ch_names=[f"ch{i}" for i in range(x_tc.shape[1])],
|
| 29 |
+
sfreq=cfg.fs,
|
| 30 |
+
ch_types="eeg",
|
| 31 |
+
)
|
| 32 |
+
raw = mne.io.RawArray(x_tc.T, info, verbose=False)
|
| 33 |
+
raw_filt = raw.copy().filter(cfg.f_low, cfg.f_high, verbose=False)
|
| 34 |
+
return raw_filt.get_data().T
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
|
| 38 |
+
Xf = np.fft.fft(x_tc, axis=0)
|
| 39 |
+
N = Xf.shape[0]
|
| 40 |
+
h = np.zeros(N)
|
| 41 |
+
if N % 2 == 0:
|
| 42 |
+
h[0] = h[N // 2] = 1
|
| 43 |
+
h[1:N // 2] = 2
|
| 44 |
+
else:
|
| 45 |
+
h[0] = 1
|
| 46 |
+
h[1:(N + 1) // 2] = 2
|
| 47 |
+
env = np.abs(np.fft.ifft(Xf * h[:, None], axis=0))
|
| 48 |
+
return env.astype(np.float32)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def preprocess_pipeline(x: np.ndarray, cfg: PreprocessConfig):
|
| 52 |
+
x_tc = to_time_channel(x)
|
| 53 |
+
x_filt = bandpass_tc(x_tc, cfg)
|
| 54 |
+
env = hilbert_envelope_tc(x_filt)
|
| 55 |
+
return {
|
| 56 |
+
"raw": x_tc,
|
| 57 |
+
"filtered": x_filt,
|
| 58 |
+
"envelope": env,
|
| 59 |
+
}
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,856 @@
|
|
| 1 |
-
import
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
import streamlit as st
|
| 7 |
+
import plotly.graph_objects as go
|
| 8 |
+
import mne
|
| 9 |
+
from scipy.signal import hilbert
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
import community as community_louvain
|
| 13 |
+
import networkx as nx
|
| 14 |
+
LOUVAIN_AVAILABLE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
LOUVAIN_AVAILABLE = False
|
| 17 |
+
st.warning("⚠️ Louvainクラスタリングを使用するには `pip install python-louvain networkx` を実行してください。")
|
| 18 |
+
|
| 19 |
+
from loader import (
|
| 20 |
+
pick_set_fdt,
|
| 21 |
+
load_eeglab_tc_from_bytes,
|
| 22 |
+
load_mat_candidates,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
st.set_page_config(page_title="EEG Viewer + Network Estimation", layout="wide")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ============================================================
|
| 29 |
+
# Preprocess config
|
| 30 |
+
# ============================================================
|
| 31 |
+
@dataclass(frozen=True)
|
| 32 |
+
class PreprocessConfig:
|
| 33 |
+
fs: float
|
| 34 |
+
f_low: float
|
| 35 |
+
f_high: float
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ============================================================
|
| 39 |
+
# Helpers
|
| 40 |
+
# ============================================================
|
| 41 |
+
def ensure_tc(x: np.ndarray) -> np.ndarray:
|
| 42 |
+
"""Ensure array is (T,C). Accept (T,), (T,C), (C,T) with heuristic transpose."""
|
| 43 |
+
x = np.asarray(x)
|
| 44 |
+
if x.ndim == 1:
|
| 45 |
+
return x[:, None]
|
| 46 |
+
if x.ndim != 2:
|
| 47 |
+
raise ValueError(f"2次元配列のみ対応です: shape={x.shape}")
|
| 48 |
+
T, C = x.shape
|
| 49 |
+
if T <= 256 and C > T: # heuristic transpose
|
| 50 |
+
x = x.T
|
| 51 |
+
return x
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ============================================================
|
| 55 |
+
# Signal processing
|
| 56 |
+
# ============================================================
|
| 57 |
+
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
|
| 58 |
+
"""Bandpass filter each channel using MNE RawArray. Input/Output: (T,C)."""
|
| 59 |
+
info = mne.create_info(
|
| 60 |
+
ch_names=[f"ch{i}" for i in range(x_tc.shape[1])],
|
| 61 |
+
sfreq=float(cfg.fs),
|
| 62 |
+
ch_types="eeg",
|
| 63 |
+
)
|
| 64 |
+
raw = mne.io.RawArray(x_tc.T, info, verbose=False) # (C,T)
|
| 65 |
+
raw_filt = raw.copy().filter(l_freq=cfg.f_low, h_freq=cfg.f_high, verbose=False)
|
| 66 |
+
return raw_filt.get_data().T.astype(np.float32)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
|
| 70 |
+
"""Hilbert envelope per channel using SciPy. Input/Output: (T,C)."""
|
| 71 |
+
analytic = hilbert(x_tc, axis=0)
|
| 72 |
+
return np.abs(analytic).astype(np.float32)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def hilbert_phase_tc(x_tc: np.ndarray) -> np.ndarray:
|
| 76 |
+
"""Hilbert phase per channel using SciPy. Input/Output: (T,C)."""
|
| 77 |
+
analytic = hilbert(x_tc, axis=0)
|
| 78 |
+
return np.angle(analytic).astype(np.float32)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def preprocess_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> dict:
|
| 82 |
+
"""raw(T,C) -> filtered/envelope/phase をまとめて返す"""
|
| 83 |
+
x_tc = ensure_tc(x_tc).astype(np.float32)
|
| 84 |
+
x_filt = bandpass_tc(x_tc, cfg)
|
| 85 |
+
env = hilbert_envelope_tc(x_filt)
|
| 86 |
+
phase = hilbert_phase_tc(x_filt)
|
| 87 |
+
return {
|
| 88 |
+
"fs": float(cfg.fs),
|
| 89 |
+
"raw": x_tc,
|
| 90 |
+
"filtered": x_filt,
|
| 91 |
+
"envelope": env,
|
| 92 |
+
"amplitude": env, # envelope のエイリアス
|
| 93 |
+
"phase": phase
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@st.cache_data(show_spinner=False)
|
| 98 |
+
def preprocess_all_eeglab(
|
| 99 |
+
set_bytes: bytes,
|
| 100 |
+
fdt_bytes: bytes,
|
| 101 |
+
set_name: str,
|
| 102 |
+
fdt_name: str,
|
| 103 |
+
f_low: float,
|
| 104 |
+
f_high: float,
|
| 105 |
+
) -> dict:
|
| 106 |
+
"""
|
| 107 |
+
EEGLAB bytes -> load -> auto preprocess (bandpass + hilbert).
|
| 108 |
+
fsは読み込んだデータのものを使う。
|
| 109 |
+
"""
|
| 110 |
+
x_tc, fs = load_eeglab_tc_from_bytes(
|
| 111 |
+
set_bytes=set_bytes,
|
| 112 |
+
set_name=set_name,
|
| 113 |
+
fdt_bytes=fdt_bytes,
|
| 114 |
+
fdt_name=fdt_name,
|
| 115 |
+
)
|
| 116 |
+
cfg = PreprocessConfig(fs=float(fs), f_low=float(f_low), f_high=float(f_high))
|
| 117 |
+
return preprocess_tc(x_tc, cfg)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@st.cache_data(show_spinner=False)
|
| 121 |
+
def load_mat_candidates_cached(mat_bytes: bytes) -> dict:
|
| 122 |
+
"""MAT candidatesをキャッシュ(UI操作で毎回読まない)"""
|
| 123 |
+
return load_mat_candidates(mat_bytes)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# ============================================================
|
| 127 |
+
# Viewer
|
| 128 |
+
# ============================================================
|
| 129 |
+
def window_slice(X_tc: np.ndarray, start_idx: int, end_idx: int, decim: int) -> np.ndarray:
|
| 130 |
+
start_idx = max(0, min(start_idx, X_tc.shape[0] - 1))
|
| 131 |
+
end_idx = max(start_idx + 1, min(end_idx, X_tc.shape[0]))
|
| 132 |
+
decim = max(1, int(decim))
|
| 133 |
+
return X_tc[start_idx:end_idx:decim, :]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def make_timeseries_figure(
|
| 137 |
+
X_tc: np.ndarray,
|
| 138 |
+
selected_channels: List[int],
|
| 139 |
+
fs: float,
|
| 140 |
+
start_sec: float,
|
| 141 |
+
win_sec: float,
|
| 142 |
+
decim: int,
|
| 143 |
+
offset_mode: bool,
|
| 144 |
+
show_rangeslider: bool,
|
| 145 |
+
signal_type: str = "filtered",
|
| 146 |
+
) -> go.Figure:
|
| 147 |
+
start_idx = int(round(start_sec * fs))
|
| 148 |
+
end_idx = int(round((start_sec + win_sec) * fs))
|
| 149 |
+
|
| 150 |
+
Xw = window_slice(X_tc, start_idx, end_idx, decim)
|
| 151 |
+
Tw = Xw.shape[0]
|
| 152 |
+
t = (np.arange(Tw) * decim + start_idx) / fs
|
| 153 |
+
|
| 154 |
+
fig = go.Figure()
|
| 155 |
+
|
| 156 |
+
if not selected_channels:
|
| 157 |
+
fig.update_layout(
|
| 158 |
+
title="Timeseries (no channel selected)",
|
| 159 |
+
height=450,
|
| 160 |
+
xaxis_title="time (s)",
|
| 161 |
+
yaxis_title="amplitude",
|
| 162 |
+
)
|
| 163 |
+
return fig
|
| 164 |
+
|
| 165 |
+
# 位相データの場合は特別な処理
|
| 166 |
+
is_phase = signal_type == "phase"
|
| 167 |
+
|
| 168 |
+
if offset_mode and len(selected_channels) > 1 and not is_phase:
|
| 169 |
+
per_ch_std = np.std(Xw[:, selected_channels], axis=0)
|
| 170 |
+
base = float(np.median(per_ch_std)) if np.isfinite(np.median(per_ch_std)) and np.median(per_ch_std) > 0 else 1.0
|
| 171 |
+
offset = 5.0 * base
|
| 172 |
+
|
| 173 |
+
for k, ch in enumerate(selected_channels):
|
| 174 |
+
y = Xw[:, ch] + k * offset
|
| 175 |
+
fig.add_trace(go.Scatter(x=t, y=y, mode="lines", name=f"ch{ch}", line=dict(width=1)))
|
| 176 |
+
ylab = "amplitude (offset)"
|
| 177 |
+
else:
|
| 178 |
+
for ch in selected_channels:
|
| 179 |
+
fig.add_trace(go.Scatter(x=t, y=Xw[:, ch], mode="lines", name=f"ch{ch}", line=dict(width=1)))
|
| 180 |
+
|
| 181 |
+
if is_phase:
|
| 182 |
+
ylab = "phase (rad)"
|
| 183 |
+
else:
|
| 184 |
+
ylab = "amplitude"
|
| 185 |
+
|
| 186 |
+
# rangeslider の高さを考慮して調整
|
| 187 |
+
plot_height = 550 if show_rangeslider else 450
|
| 188 |
+
bottom_margin = 150 if show_rangeslider else 80
|
| 189 |
+
|
| 190 |
+
title_text = f"Timeseries: {signal_type} (window={win_sec:.2f}s, start={start_sec:.2f}s, decim={decim})"
|
| 191 |
+
|
| 192 |
+
fig.update_layout(
|
| 193 |
+
title=title_text,
|
| 194 |
+
height=plot_height,
|
| 195 |
+
xaxis_title="time (s)",
|
| 196 |
+
yaxis_title=ylab,
|
| 197 |
+
legend=dict(orientation="h"),
|
| 198 |
+
margin=dict(l=60, r=20, t=80, b=bottom_margin),
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
# 位相の場合は y軸の範囲を -π ~ π に固定
|
| 202 |
+
if is_phase:
|
| 203 |
+
fig.update_yaxes(range=[-np.pi - 0.5, np.pi + 0.5])
|
| 204 |
+
|
| 205 |
+
if show_rangeslider:
|
| 206 |
+
fig.update_xaxes(
|
| 207 |
+
rangeslider=dict(
|
| 208 |
+
visible=True,
|
| 209 |
+
thickness=0.05,
|
| 210 |
+
)
|
| 211 |
+
)
|
| 212 |
+
else:
|
| 213 |
+
fig.update_xaxes(rangeslider=dict(visible=False))
|
| 214 |
+
|
| 215 |
+
return fig
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
# ============================================================
|
| 219 |
+
# Network (multiple methods) + export
|
| 220 |
+
# ============================================================
|
| 221 |
+
def estimate_network_envelope_corr(X_tc: np.ndarray) -> np.ndarray:
|
| 222 |
+
"""
|
| 223 |
+
Envelope (amplitude) の Pearson 相関係数を計算。
|
| 224 |
+
Input: X_tc (T, C) - envelope データ
|
| 225 |
+
Output: W (C, C) - 相関係数の絶対値
|
| 226 |
+
"""
|
| 227 |
+
X = X_tc - X_tc.mean(axis=0, keepdims=True)
|
| 228 |
+
corr = np.corrcoef(X, rowvar=False)
|
| 229 |
+
W = np.abs(corr)
|
| 230 |
+
np.fill_diagonal(W, 0.0)
|
| 231 |
+
return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def estimate_network_phase_corr(X_tc: np.ndarray) -> np.ndarray:
|
| 235 |
+
"""
|
| 236 |
+
Phase の circular correlation (位相同期指標) を計算。
|
| 237 |
+
Input: X_tc (T, C) - phase データ (ラジアン)
|
| 238 |
+
Output: W (C, C) - circular correlation
|
| 239 |
+
|
| 240 |
+
Circular correlation は以下で計算:
|
| 241 |
+
r_ij = |⟨exp(i*(θ_i - θ_j))⟩_t|
|
| 242 |
+
これは Phase Locking Value (PLV) とも呼ばれます。
|
| 243 |
+
"""
|
| 244 |
+
T, C = X_tc.shape
|
| 245 |
+
W = np.zeros((C, C), dtype=np.float32)
|
| 246 |
+
|
| 247 |
+
# 各チャンネルペアについて circular correlation を計算
|
| 248 |
+
for i in range(C):
|
| 249 |
+
for j in range(i + 1, C):
|
| 250 |
+
# 位相差
|
| 251 |
+
phase_diff = X_tc[:, i] - X_tc[:, j]
|
| 252 |
+
# PLV: |mean(exp(i*phase_diff))|
|
| 253 |
+
plv = np.abs(np.mean(np.exp(1j * phase_diff)))
|
| 254 |
+
W[i, j] = plv
|
| 255 |
+
W[j, i] = plv
|
| 256 |
+
|
| 257 |
+
return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def estimate_network_dummy(X_tc: np.ndarray) -> np.ndarray:
|
| 261 |
+
"""
|
| 262 |
+
ダミー実装: 単純な相関係数の絶対値
|
| 263 |
+
(後方互換性のため残す)
|
| 264 |
+
"""
|
| 265 |
+
X = X_tc - X_tc.mean(axis=0, keepdims=True)
|
| 266 |
+
corr = np.corrcoef(X, rowvar=False)
|
| 267 |
+
W = np.abs(corr)
|
| 268 |
+
np.fill_diagonal(W, 0.0)
|
| 269 |
+
return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def threshold_edges(W: np.ndarray, thr: float) -> List[Tuple[int, int, float]]:
|
| 273 |
+
C = W.shape[0]
|
| 274 |
+
edges: List[Tuple[int, int, float]] = []
|
| 275 |
+
for i in range(C):
|
| 276 |
+
for j in range(i + 1, C):
|
| 277 |
+
w = float(W[i, j])
|
| 278 |
+
if w >= thr:
|
| 279 |
+
edges.append((i, j, w))
|
| 280 |
+
edges.sort(key=lambda x: x[2], reverse=True)
|
| 281 |
+
return edges
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def adjacency_at_threshold(W: np.ndarray, thr: float, weighted: bool) -> np.ndarray:
|
| 285 |
+
if weighted:
|
| 286 |
+
A = W.copy()
|
| 287 |
+
A[A < thr] = 0.0
|
| 288 |
+
np.fill_diagonal(A, 0.0)
|
| 289 |
+
return A
|
| 290 |
+
A = (W >= thr).astype(int)
|
| 291 |
+
np.fill_diagonal(A, 0)
|
| 292 |
+
return A
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def compute_louvain_clusters(W: np.ndarray, thr: float) -> np.ndarray:
|
| 296 |
+
"""
|
| 297 |
+
Louvain法でクラスタリングを実行。
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
W: 重み行列 (C, C)
|
| 301 |
+
thr: 閾値(これ以下のエッジは削除)
|
| 302 |
+
|
| 303 |
+
Returns:
|
| 304 |
+
clusters: クラスタID配列 (C,)
|
| 305 |
+
"""
|
| 306 |
+
if not LOUVAIN_AVAILABLE:
|
| 307 |
+
# Louvainが使えない場合は全ノードを同じクラスタに
|
| 308 |
+
return np.zeros(W.shape[0], dtype=int)
|
| 309 |
+
|
| 310 |
+
# NetworkXグラフを作成
|
| 311 |
+
G = nx.Graph()
|
| 312 |
+
C = W.shape[0]
|
| 313 |
+
G.add_nodes_from(range(C))
|
| 314 |
+
|
| 315 |
+
# 閾値以上のエッジを追加
|
| 316 |
+
for i in range(C):
|
| 317 |
+
for j in range(i + 1, C):
|
| 318 |
+
if W[i, j] >= thr:
|
| 319 |
+
G.add_edge(i, j, weight=W[i, j])
|
| 320 |
+
|
| 321 |
+
# Louvain法でコミュニティ検出
|
| 322 |
+
partition = community_louvain.best_partition(G, weight='weight')
|
| 323 |
+
|
| 324 |
+
# クラスタIDの配列に変換
|
| 325 |
+
clusters = np.array([partition[i] for i in range(C)])
|
| 326 |
+
|
| 327 |
+
return clusters
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def get_cluster_colors(clusters: np.ndarray) -> List[str]:
|
| 331 |
+
"""
|
| 332 |
+
クラスタIDから色のリストを生成。
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
clusters: クラスタID配列 (C,)
|
| 336 |
+
|
| 337 |
+
Returns:
|
| 338 |
+
colors: 色のリスト
|
| 339 |
+
"""
|
| 340 |
+
import colorsys
|
| 341 |
+
|
| 342 |
+
n_clusters = len(np.unique(clusters))
|
| 343 |
+
|
| 344 |
+
# クラスタ数に応じて色相を均等に分割
|
| 345 |
+
colors = []
|
| 346 |
+
for cluster_id in clusters:
|
| 347 |
+
hue = cluster_id / max(n_clusters, 1)
|
| 348 |
+
r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.95)
|
| 349 |
+
colors.append(f'rgb({int(255*r)}, {int(255*g)}, {int(255*b)})')
|
| 350 |
+
|
| 351 |
+
return colors
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def make_network_figure(W: np.ndarray, thr: float, use_louvain: bool = True) -> tuple[go.Figure, int]:
|
| 355 |
+
C = W.shape[0]
|
| 356 |
+
angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
|
| 357 |
+
xs = np.cos(angles)
|
| 358 |
+
ys = np.sin(angles)
|
| 359 |
+
|
| 360 |
+
edges = threshold_edges(W, thr)
|
| 361 |
+
fig = go.Figure()
|
| 362 |
+
|
| 363 |
+
# エッジの重みの範囲を取得(色と太さのスケーリング用)
|
| 364 |
+
if edges:
|
| 365 |
+
weights = [w for _, _, w in edges]
|
| 366 |
+
min_w = min(weights)
|
| 367 |
+
max_w = max(weights)
|
| 368 |
+
weight_range = max_w - min_w if max_w > min_w else 1.0
|
| 369 |
+
else:
|
| 370 |
+
min_w = 0
|
| 371 |
+
max_w = 1
|
| 372 |
+
weight_range = 1.0
|
| 373 |
+
|
| 374 |
+
# レインボーカラーマップ関数 (0=青 → 0.5=緑/黄 → 1=赤)
|
| 375 |
+
def get_rainbow_color(norm_val):
|
| 376 |
+
"""正規化された値 (0-1) からレインボーカラーを生成"""
|
| 377 |
+
import colorsys
|
| 378 |
+
# HSVのHue: 240°(青) → 0°(赤) に変換
|
| 379 |
+
hue = (1.0 - norm_val) * 0.67 # 0.67 ≈ 240/360 (青)
|
| 380 |
+
r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
|
| 381 |
+
return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'
|
| 382 |
+
|
| 383 |
+
# エッジを描画(重みに応じて色と太さを変える)
|
| 384 |
+
for (i, j, w) in edges:
|
| 385 |
+
# 正規化された重み (0-1)
|
| 386 |
+
norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
|
| 387 |
+
|
| 388 |
+
# レインボーカラー: 弱い(青) → 中間(緑/黄) → 強い(赤)
|
| 389 |
+
color = get_rainbow_color(norm_w)
|
| 390 |
+
|
| 391 |
+
# 太さ: 重みに比例 (0.5-4の範囲)
|
| 392 |
+
line_width = 0.5 + 3.5 * norm_w
|
| 393 |
+
|
| 394 |
+
fig.add_trace(
|
| 395 |
+
go.Scatter(
|
| 396 |
+
x=[xs[i], xs[j]],
|
| 397 |
+
y=[ys[i], ys[j]],
|
| 398 |
+
mode="lines",
|
| 399 |
+
hoverinfo="text",
|
| 400 |
+
hovertext=f"ch{i} - ch{j}<br>weight: {w:.4f}",
|
| 401 |
+
line=dict(width=line_width, color=color),
|
| 402 |
+
showlegend=False,
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# Louvainクラスタリング
|
| 407 |
+
if use_louvain and LOUVAIN_AVAILABLE:
|
| 408 |
+
clusters = compute_louvain_clusters(W, thr)
|
| 409 |
+
node_colors = get_cluster_colors(clusters)
|
| 410 |
+
n_clusters = len(np.unique(clusters))
|
| 411 |
+
title_suffix = f" | Louvain clusters: {n_clusters}"
|
| 412 |
+
else:
|
| 413 |
+
node_colors = ['#FFD700'] * C # デフォルトのゴールド
|
| 414 |
+
clusters = np.zeros(C, dtype=int)
|
| 415 |
+
title_suffix = ""
|
| 416 |
+
|
| 417 |
+
# ノードを描画
|
| 418 |
+
fig.add_trace(
|
| 419 |
+
go.Scatter(
|
| 420 |
+
x=xs,
|
| 421 |
+
y=ys,
|
| 422 |
+
mode="markers+text",
|
| 423 |
+
text=[f"{k}" for k in range(C)],
|
| 424 |
+
textposition="bottom center",
|
| 425 |
+
marker=dict(
|
| 426 |
+
size=14,
|
| 427 |
+
color=node_colors,
|
| 428 |
+
line=dict(width=2, color='white')
|
| 429 |
+
),
|
| 430 |
+
hoverinfo="text",
|
| 431 |
+
hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(C)],
|
| 432 |
+
showlegend=False,
|
| 433 |
+
)
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
fig.update_layout(
|
| 437 |
+
title=f"Estimated Network (thr={thr:.3f}) edges={len(edges)}{title_suffix}",
|
| 438 |
+
height=500,
|
| 439 |
+
xaxis=dict(visible=False),
|
| 440 |
+
yaxis=dict(visible=False),
|
| 441 |
+
margin=dict(l=10, r=10, t=50, b=10),
|
| 442 |
+
paper_bgcolor='rgba(0,0,0,0.9)',
|
| 443 |
+
plot_bgcolor='rgba(0,0,0,0.9)',
|
| 444 |
+
)
|
| 445 |
+
fig.update_yaxes(scaleanchor="x", scaleratio=1)
|
| 446 |
+
|
| 447 |
+
# カラーバー的な説明を追加
|
| 448 |
+
if edges:
|
| 449 |
+
fig.add_annotation(
|
| 450 |
+
text=f"Edge color/width: weak (blue/thin) → medium (green/yellow) → strong (red/thick)<br>Weight range: {min_w:.3f} - {max_w:.3f}",
|
| 451 |
+
xref="paper", yref="paper",
|
| 452 |
+
x=0.5, y=-0.05,
|
| 453 |
+
showarrow=False,
|
| 454 |
+
font=dict(size=10, color='white'),
|
| 455 |
+
xanchor='center',
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
return fig, len(edges)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def make_edgecount_curve(W: np.ndarray) -> go.Figure:
|
| 462 |
+
vals = np.sort(W[np.triu_indices(W.shape[0], k=1)])
|
| 463 |
+
thr_grid = np.linspace(float(vals.max()), float(vals.min()), 120) if vals.size else np.array([0.0])
|
| 464 |
+
counts = [len(threshold_edges(W, float(thr))) for thr in thr_grid]
|
| 465 |
+
|
| 466 |
+
fig = go.Figure()
|
| 467 |
+
fig.add_trace(go.Scatter(x=thr_grid, y=counts, mode="lines"))
|
| 468 |
+
fig.update_layout(
|
| 469 |
+
title="Edge count vs threshold (lower thr => more edges)",
|
| 470 |
+
xaxis_title="threshold",
|
| 471 |
+
yaxis_title="edge count",
|
| 472 |
+
height=300,
|
| 473 |
+
)
|
| 474 |
+
return fig
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def to_csv_bytes_matrix(mat: np.ndarray, fmt: str) -> bytes:
|
| 478 |
+
buf = io.StringIO()
|
| 479 |
+
np.savetxt(buf, mat, delimiter=",", fmt=fmt)
|
| 480 |
+
return buf.getvalue().encode("utf-8")
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
def to_csv_bytes_edges(edges: List[Tuple[int, int, float]]) -> bytes:
|
| 484 |
+
buf = io.StringIO()
|
| 485 |
+
buf.write("source,target,weight\n")
|
| 486 |
+
for i, j, w in edges:
|
| 487 |
+
buf.write(f"{i},{j},{w:.6f}\n")
|
| 488 |
+
return buf.getvalue().encode("utf-8")
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
# ============================================================
|
| 492 |
+
# Sidebar UI
|
| 493 |
+
# ============================================================
|
| 494 |
+
st.sidebar.header("Input format")
|
| 495 |
+
input_mode = st.sidebar.radio("データ形式", ["EEGLAB (.set + .fdt)", "MATLAB (.mat)"], index=0)
|
| 496 |
+
|
| 497 |
+
st.sidebar.header("Preprocess (auto)")
|
| 498 |
+
f_low = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=8.0, step=0.5)
|
| 499 |
+
f_high = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=12.0, step=0.5)
|
| 500 |
+
|
| 501 |
+
st.sidebar.header("Viewer controls")
|
| 502 |
+
win_sec = st.sidebar.number_input("Window length (sec)", min_value=0.1, value=5.0, step=0.1)
|
| 503 |
+
decim = st.sidebar.selectbox("Decimation (間引き)", options=[1, 2, 5, 10, 20, 50], index=1)
|
| 504 |
+
offset_mode = st.sidebar.checkbox("重ね描画のオフセット表示", value=True)
|
| 505 |
+
show_rangeslider = st.sidebar.checkbox("Plotly rangesliderを表示", value=False)
|
| 506 |
+
signal_view = st.sidebar.radio(
|
| 507 |
+
"表示する信号",
|
| 508 |
+
["raw", "filtered", "amplitude", "phase"],
|
| 509 |
+
index=1,
|
| 510 |
+
help="raw: 生信号, filtered: バンドパス後, amplitude: Hilbert振幅(envelope), phase: Hilbert位相"
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
st.title("EEG timeseries viewer + network estimation")
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
# ============================================================
|
| 517 |
+
# Load + preprocess (EEGLAB / MAT)
|
| 518 |
+
# ============================================================
|
| 519 |
+
if input_mode.startswith("EEGLAB"):
|
| 520 |
+
st.sidebar.header("Upload (.set + .fdt)")
|
| 521 |
+
uploaded_files = st.sidebar.file_uploader(
|
| 522 |
+
"Upload EEGLAB files",
|
| 523 |
+
type=["set", "fdt"],
|
| 524 |
+
accept_multiple_files=True,
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
if uploaded_files:
|
| 528 |
+
set_file, fdt_file = pick_set_fdt(uploaded_files)
|
| 529 |
+
if set_file is None or fdt_file is None:
|
| 530 |
+
st.warning("`.set` と `.fdt` の両方をアップロードしてください。")
|
| 531 |
+
else:
|
| 532 |
+
try:
|
| 533 |
+
with st.spinner("Loading EEGLAB + preprocessing (bandpass + hilbert)..."):
|
| 534 |
+
prep = preprocess_all_eeglab(
|
| 535 |
+
set_bytes=set_file.getvalue(),
|
| 536 |
+
fdt_bytes=fdt_file.getvalue(),
|
| 537 |
+
set_name=set_file.name,
|
| 538 |
+
fdt_name=fdt_file.name,
|
| 539 |
+
f_low=float(f_low),
|
| 540 |
+
f_high=float(f_high),
|
| 541 |
+
)
|
| 542 |
+
st.session_state["prep"] = prep
|
| 543 |
+
st.session_state["W"] = None
|
| 544 |
+
st.success(f"Loaded & preprocessed. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
|
| 545 |
+
except Exception as e:
|
| 546 |
+
st.session_state.pop("prep", None)
|
| 547 |
+
st.session_state["W"] = None
|
| 548 |
+
st.error(f"読み込み/前処理エラー: {e}")
|
| 549 |
+
|
| 550 |
+
else:
|
| 551 |
+
st.sidebar.header("Upload (.mat)")
|
| 552 |
+
mat_file = st.sidebar.file_uploader("Upload .mat", type=["mat"])
|
| 553 |
+
|
| 554 |
+
if mat_file is not None:
|
| 555 |
+
mat_bytes = mat_file.getvalue()
|
| 556 |
+
try:
|
| 557 |
+
cands = load_mat_candidates_cached(mat_bytes)
|
| 558 |
+
if not cands:
|
| 559 |
+
st.error("数値の1D/2D配列が見つかりませんでした。")
|
| 560 |
+
st.info("MATファイルの構造を確認しています...")
|
| 561 |
+
|
| 562 |
+
# デバッグ: MATファイルの中身を表示
|
| 563 |
+
try:
|
| 564 |
+
from scipy.io import loadmat
|
| 565 |
+
mat_data = loadmat(io.BytesIO(mat_bytes))
|
| 566 |
+
st.write("**MATファイルに含まれる変数:**")
|
| 567 |
+
for k, v in mat_data.items():
|
| 568 |
+
if not k.startswith('__'):
|
| 569 |
+
if isinstance(v, np.ndarray):
|
| 570 |
+
st.write(f"- `{k}`: shape={v.shape}, dtype={v.dtype}, ndim={v.ndim}")
|
| 571 |
+
else:
|
| 572 |
+
st.write(f"- `{k}`: type={type(v).__name__}")
|
| 573 |
+
except Exception as e:
|
| 574 |
+
st.write(f"デバッグ情報の取得に失敗: {e}")
|
| 575 |
+
|
| 576 |
+
# HDF5形式の場合も試す
|
| 577 |
+
try:
|
| 578 |
+
import h5py
|
| 579 |
+
import tempfile
|
| 580 |
+
with tempfile.NamedTemporaryFile(suffix='.mat', delete=False) as tmp:
|
| 581 |
+
tmp.write(mat_bytes)
|
| 582 |
+
tmp_path = tmp.name
|
| 583 |
+
|
| 584 |
+
st.write("**HDF5形式として読み込み中...**")
|
| 585 |
+
with h5py.File(tmp_path, 'r') as f:
|
| 586 |
+
def show_structure(name, obj):
|
| 587 |
+
if isinstance(obj, h5py.Dataset):
|
| 588 |
+
st.write(f"- `{name}`: shape={obj.shape}, dtype={obj.dtype}")
|
| 589 |
+
f.visititems(show_structure)
|
| 590 |
+
|
| 591 |
+
import os
|
| 592 |
+
os.unlink(tmp_path)
|
| 593 |
+
except Exception as e2:
|
| 594 |
+
st.write(f"HDF5としても読み込めませんでした: {e2}")
|
| 595 |
+
else:
|
| 596 |
+
key = st.sidebar.selectbox("EEG配列(変数)を選択", options=list(cands.keys()))
|
| 597 |
+
fs_mat = st.sidebar.number_input("Sampling rate (Hz)", min_value=0.1, value=256.0, step=0.1)
|
| 598 |
+
|
| 599 |
+
# 変数が選択されたら自動的に前処理を実行
|
| 600 |
+
if key:
|
| 601 |
+
x = cands[key]
|
| 602 |
+
st.sidebar.write(f"選択した配列: shape={x.shape}, dtype={x.dtype}")
|
| 603 |
+
try:
|
| 604 |
+
with st.spinner("Preprocessing (bandpass + hilbert)..."):
|
| 605 |
+
cfg = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low), f_high=float(f_high))
|
| 606 |
+
prep = preprocess_tc(x, cfg)
|
| 607 |
+
|
| 608 |
+
st.session_state["prep"] = prep
|
| 609 |
+
st.session_state["W"] = None
|
| 610 |
+
st.success(f"Loaded MAT '{key}'. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
|
| 611 |
+
except Exception as e:
|
| 612 |
+
st.session_state.pop("prep", None)
|
| 613 |
+
st.session_state["W"] = None
|
| 614 |
+
st.error(f"前処理エラー: {e}")
|
| 615 |
+
import traceback
|
| 616 |
+
st.code(traceback.format_exc())
|
| 617 |
+
except Exception as e:
|
| 618 |
+
st.session_state.pop("prep", None)
|
| 619 |
+
st.session_state["W"] = None
|
| 620 |
+
st.error(f".mat 読み込みエラー: {e}")
|
| 621 |
+
import traceback
|
| 622 |
+
st.code(traceback.format_exc())
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
if "prep" not in st.session_state:
|
| 626 |
+
st.info("左のサイドバーからデータをアップロードしてください。")
|
| 627 |
+
st.stop()
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
# ============================================================
|
| 631 |
+
# Viewer
|
| 632 |
+
# ============================================================
|
| 633 |
+
prep = st.session_state["prep"]
|
| 634 |
+
fs = float(prep["fs"])
|
| 635 |
+
X_tc = prep[signal_view]
|
| 636 |
+
T, C = X_tc.shape
|
| 637 |
+
|
| 638 |
+
duration_sec = (T - 1) / fs if T > 1 else 0.0
|
| 639 |
+
max_start = max(0.0, float(duration_sec - win_sec))
|
| 640 |
+
|
| 641 |
+
start_sec = st.sidebar.slider(
|
| 642 |
+
"Start time (sec)",
|
| 643 |
+
min_value=0.0,
|
| 644 |
+
max_value=float(max_start),
|
| 645 |
+
value=0.0,
|
| 646 |
+
step=float(max(0.01, win_sec / 200)),
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
st.sidebar.header("Channels")
|
| 650 |
+
|
| 651 |
+
# チャンネル選択の便利機能
|
| 652 |
+
col_ch1, col_ch2 = st.sidebar.columns(2)
|
| 653 |
+
with col_ch1:
|
| 654 |
+
select_all = st.button("全選択")
|
| 655 |
+
with col_ch2:
|
| 656 |
+
deselect_all = st.button("全解除")
|
| 657 |
+
|
| 658 |
+
# 範囲選択
|
| 659 |
+
with st.sidebar.expander("📊 範囲で選択"):
|
| 660 |
+
range_start = st.number_input("開始ch", min_value=0, max_value=C-1, value=0, step=1)
|
| 661 |
+
range_end = st.number_input("終了ch", min_value=0, max_value=C-1, value=min(C-1, 7), step=1)
|
| 662 |
+
if st.button("範囲を選択"):
|
| 663 |
+
st.session_state["selected_channels"] = list(range(int(range_start), int(range_end) + 1))
|
| 664 |
+
|
| 665 |
+
# プリセット選択
|
| 666 |
+
with st.sidebar.expander("⚡ プリセット"):
|
| 667 |
+
preset_col1, preset_col2 = st.columns(2)
|
| 668 |
+
with preset_col1:
|
| 669 |
+
if st.button("前頭部 (0-15)"):
|
| 670 |
+
st.session_state["selected_channels"] = list(range(min(16, C)))
|
| 671 |
+
with preset_col2:
|
| 672 |
+
if st.button("頭頂部 (16-31)"):
|
| 673 |
+
st.session_state["selected_channels"] = list(range(16, min(32, C)))
|
| 674 |
+
preset_col3, preset_col4 = st.columns(2)
|
| 675 |
+
with preset_col3:
|
| 676 |
+
if st.button("側頭部 (32-47)"):
|
| 677 |
+
st.session_state["selected_channels"] = list(range(32, min(48, C)))
|
| 678 |
+
with preset_col4:
|
| 679 |
+
if st.button("後頭部 (48-63)"):
|
| 680 |
+
st.session_state["selected_channels"] = list(range(48, min(64, C)))
|
| 681 |
+
|
| 682 |
+
# セッションステートの初期化
|
| 683 |
+
if "selected_channels" not in st.session_state:
|
| 684 |
+
st.session_state["selected_channels"] = list(range(min(C, 8)))
|
| 685 |
+
|
| 686 |
+
# ボタンによる選択の処理
|
| 687 |
+
if select_all:
|
| 688 |
+
st.session_state["selected_channels"] = list(range(C))
|
| 689 |
+
if deselect_all:
|
| 690 |
+
st.session_state["selected_channels"] = []
|
| 691 |
+
|
| 692 |
+
# メインの選択UI(最大表示数を制限)
|
| 693 |
+
max_display = 20 # multiselect で一度に表示する数を制限
|
| 694 |
+
if C <= max_display:
|
| 695 |
+
selected_channels = st.sidebar.multiselect(
|
| 696 |
+
f"表示するチャンネル(全{C}ch)",
|
| 697 |
+
options=list(range(C)),
|
| 698 |
+
default=st.session_state["selected_channels"],
|
| 699 |
+
key="ch_select",
|
| 700 |
+
)
|
| 701 |
+
else:
|
| 702 |
+
# 大量のチャンネルがある場合は、選択済みのものだけ表示
|
| 703 |
+
st.sidebar.caption(f"選択中: {len(st.session_state['selected_channels'])} / {C} channels")
|
| 704 |
+
|
| 705 |
+
# 個別追加
|
| 706 |
+
add_ch = st.sidebar.number_input(
|
| 707 |
+
"チャンネルを追加",
|
| 708 |
+
min_value=0,
|
| 709 |
+
max_value=C-1,
|
| 710 |
+
value=0,
|
| 711 |
+
step=1,
|
| 712 |
+
key="add_ch_input"
|
| 713 |
+
)
|
| 714 |
+
col_add, col_remove = st.sidebar.columns(2)
|
| 715 |
+
with col_add:
|
| 716 |
+
if st.button("➕ 追加"):
|
| 717 |
+
if add_ch not in st.session_state["selected_channels"]:
|
| 718 |
+
st.session_state["selected_channels"].append(int(add_ch))
|
| 719 |
+
st.session_state["selected_channels"].sort()
|
| 720 |
+
with col_remove:
|
| 721 |
+
if st.button("➖ 削除"):
|
| 722 |
+
if add_ch in st.session_state["selected_channels"]:
|
| 723 |
+
st.session_state["selected_channels"].remove(int(add_ch))
|
| 724 |
+
|
| 725 |
+
# 現在の選択を表示
|
| 726 |
+
if st.session_state["selected_channels"]:
|
| 727 |
+
selected_str = ", ".join(map(str, st.session_state["selected_channels"][:10]))
|
| 728 |
+
if len(st.session_state["selected_channels"]) > 10:
|
| 729 |
+
selected_str += f", ... (+{len(st.session_state['selected_channels']) - 10})"
|
| 730 |
+
st.sidebar.text(f"選択済み: {selected_str}")
|
| 731 |
+
|
| 732 |
+
selected_channels = st.session_state["selected_channels"]
|
| 733 |
+
|
| 734 |
+
# セッションステートを更新(multiselectを使った場合)
|
| 735 |
+
if C <= max_display:
|
| 736 |
+
st.session_state["selected_channels"] = selected_channels
|
| 737 |
+
|
| 738 |
+
col1, col2 = st.columns([2, 1])
|
| 739 |
+
with col1:
|
| 740 |
+
fig_ts = make_timeseries_figure(
|
| 741 |
+
X_tc=X_tc,
|
| 742 |
+
selected_channels=selected_channels,
|
| 743 |
+
fs=fs,
|
| 744 |
+
start_sec=float(start_sec),
|
| 745 |
+
win_sec=float(win_sec),
|
| 746 |
+
decim=int(decim),
|
| 747 |
+
offset_mode=bool(offset_mode),
|
| 748 |
+
show_rangeslider=bool(show_rangeslider),
|
| 749 |
+
signal_type=signal_view,
|
| 750 |
+
)
|
| 751 |
+
st.plotly_chart(fig_ts)
|
| 752 |
+
|
| 753 |
+
with col2:
|
| 754 |
+
st.subheader("Data info")
|
| 755 |
+
signal_desc = {
|
| 756 |
+
"raw": "生信号(前処理なし)",
|
| 757 |
+
"filtered": f"バンドパスフィルタ後 ({f_low}-{f_high} Hz)",
|
| 758 |
+
"amplitude": "Hilbert振幅 (envelope)",
|
| 759 |
+
"phase": "Hilbert位相 (-π ~ π)"
|
| 760 |
+
}
|
| 761 |
+
st.write(f"- view: **{signal_view}** ({signal_desc.get(signal_view, '')})")
|
| 762 |
+
st.write(f"- fs: **{fs:.2f} Hz**")
|
| 763 |
+
st.write(f"- T: {T} samples")
|
| 764 |
+
st.write(f"- C: {C} channels")
|
| 765 |
+
st.write(f"- duration: {duration_sec:.2f} sec")
|
| 766 |
+
|
| 767 |
+
if signal_view == "phase":
|
| 768 |
+
st.caption("※ 位相は -π (rad) から π (rad) の範囲で表示されます")
|
| 769 |
+
|
| 770 |
+
st.caption("※ 大規模データは window + decimation 推奨。rangesliderは重い場合OFF。")
|
| 771 |
+
|
| 772 |
+
st.divider()
|
| 773 |
+
|
| 774 |
+
|
| 775 |
+
# ============================================================
|
| 776 |
+
# Estimation
|
| 777 |
+
# ============================================================
|
| 778 |
+
st.subheader("Network estimation")
|
| 779 |
+
|
| 780 |
+
# 推定手法の選択
|
| 781 |
+
estimation_method = st.radio(
|
| 782 |
+
"推定手法を選択",
|
| 783 |
+
options=[
|
| 784 |
+
"envelope_corr",
|
| 785 |
+
"phase_corr",
|
| 786 |
+
],
|
| 787 |
+
format_func=lambda x: {
|
| 788 |
+
"envelope_corr": "Envelope correlation (振幅の相関)",
|
| 789 |
+
"phase_corr": "Phase circular correlation (位相同期, PLV)",
|
| 790 |
+
}[x],
|
| 791 |
+
horizontal=True,
|
| 792 |
+
help="envelope_corr: 振幅包絡線のPearson相関係数 | phase_corr: 位相の circular correlation (Phase Locking Value)",
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
# 推定手法の説明
|
| 796 |
+
method_info = {
|
| 797 |
+
"envelope_corr": "**Envelope correlation**: 振幅包絡線(Hilbert amplitude)間のPearson相関係数を計算します。振幅が同期して変動するチャンネル間の結合を検出します。",
|
| 798 |
+
"phase_corr": "**Phase circular correlation (PLV)**: 位相間の circular correlation を計算します。Phase Locking Value (PLV) とも呼ばれ、位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
|
| 799 |
+
}
|
| 800 |
+
st.info(method_info[estimation_method])
|
| 801 |
+
|
| 802 |
+
# セッションステートから前回の手法と W を取得
|
| 803 |
+
last_method = st.session_state.get("last_estimation_method")
|
| 804 |
+
W = st.session_state.get("W")
|
| 805 |
+
|
| 806 |
+
# 推定が必要かチェック(初回 or 手法変更)
|
| 807 |
+
need_estimation = (W is None) or (last_method != estimation_method)
|
| 808 |
+
|
| 809 |
+
if need_estimation:
|
| 810 |
+
with st.spinner(f"推定中... ({estimation_method})"):
|
| 811 |
+
if estimation_method == "envelope_corr":
|
| 812 |
+
X_in = prep["amplitude"]
|
| 813 |
+
W = estimate_network_envelope_corr(X_in)
|
| 814 |
+
elif estimation_method == "phase_corr":
|
| 815 |
+
X_in = prep["phase"]
|
| 816 |
+
W = estimate_network_phase_corr(X_in)
|
| 817 |
+
else:
|
| 818 |
+
st.error("未知の推定手法です")
|
| 819 |
+
st.stop()
|
| 820 |
+
|
| 821 |
+
# セッションステートに保存
|
| 822 |
+
st.session_state["W"] = W
|
| 823 |
+
st.session_state["last_estimation_method"] = estimation_method
|
| 824 |
+
st.success(f"✅ 推定完了: {estimation_method} (ネットワークサイズ: {W.shape[0]} nodes)")
|
| 825 |
+
else:
|
| 826 |
+
st.success(f"✓ 推定済み: **{estimation_method}** (ネットワークサイズ: {W.shape[0]} nodes)")
|
| 827 |
+
|
| 828 |
+
# この時点で W は必ず存在する
|
| 829 |
+
# 閾値スライダーとネットワーク図の表示
|
| 830 |
+
wmax = float(np.max(W)) if np.isfinite(np.max(W)) else 1.0
|
| 831 |
+
|
| 832 |
+
col_thr1, col_thr2 = st.columns([3, 1])
|
| 833 |
+
with col_thr1:
|
| 834 |
+
thr = st.slider(
|
| 835 |
+
"閾値 (threshold) ※下げるほどエッジが増えます",
|
| 836 |
+
min_value=0.0,
|
| 837 |
+
max_value=max(0.0001, wmax),
|
| 838 |
+
value=min(0.5, wmax),
|
| 839 |
+
step=max(wmax / 200, 0.001),
|
| 840 |
+
)
|
| 841 |
+
with col_thr2:
|
| 842 |
+
use_louvain = st.checkbox(
|
| 843 |
+
"Louvainクラスタリング",
|
| 844 |
+
value=True,
|
| 845 |
+
disabled=not LOUVAIN_AVAILABLE,
|
| 846 |
+
help="ノードの色をコミュニティ検出結果で塗り分けます"
|
| 847 |
+
)
|
| 848 |
+
|
| 849 |
+
net_col1, net_col2 = st.columns([2, 1])
|
| 850 |
+
with net_col1:
|
| 851 |
+
fig_net, edge_n = make_network_figure(W, float(thr), use_louvain=use_louvain)
|
| 852 |
+
st.plotly_chart(fig_net)
|
| 853 |
|
| 854 |
+
with net_col2:
|
| 855 |
+
st.metric("Edges", edge_n)
|
| 856 |
+
st.plotly_chart(make_edgecount_curve(W))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|