stardust-coder commited on
Commit
5ec4552
·
1 Parent(s): b11ec91

[add] added louvain clustering

Browse files
Files changed (3) hide show
  1. requirements.txt +3 -2
  2. src/loader.py +112 -4
  3. src/streamlit_app.py +141 -9
requirements.txt CHANGED
@@ -1,7 +1,8 @@
1
- streamlit
2
  numpy
3
  plotly
4
  scipy
5
  mne
6
  h5py
7
- networkx
 
 
1
+ streamlit==1.24.0
2
  numpy
3
  plotly
4
  scipy
5
  mne
6
  h5py
7
+ networkx
8
+ python-louvain
src/loader.py CHANGED
@@ -42,8 +42,103 @@ def same_stem(a_name: str, b_name: str) -> bool:
42
  return a_stem == b_stem
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def _load_eeglab_hdf5(set_path: str, fdt_path: Optional[str] = None, debug: bool = False) -> Tuple[np.ndarray, float]:
46
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  Load EEGLAB .set file saved in MATLAB v7.3 (HDF5) format using h5py.
48
  Returns: (x_tc, fs) where x_tc is (T, C)
49
  """
@@ -227,12 +322,13 @@ def load_eeglab_tc_from_bytes(
227
  set_name: str,
228
  fdt_bytes: Optional[bytes] = None,
229
  fdt_name: Optional[str] = None,
230
- ) -> Tuple[np.ndarray, float]:
231
  """
232
  Load EEGLAB .set (and optional .fdt) from bytes using MNE or h5py.
233
  Returns:
234
  x_tc: (T, C) float32
235
  fs: sampling rate (Hz)
 
236
 
237
  Notes:
238
  - 多くのEEGLABは .set が .fdt を参照するため、同じディレクトリに同名で置く必要があります。
@@ -261,7 +357,11 @@ def load_eeglab_tc_from_bytes(
261
  raw = mne.io.read_raw_eeglab(set_path, preload=True, verbose=False)
262
  fs = float(raw.info["sfreq"])
263
  x_tc = raw.get_data().T # (T,C)
264
- return x_tc.astype(np.float32), fs
 
 
 
 
265
 
266
  except Exception as e_raw:
267
  # 2) Epochsとして読む(エポックデータ用)
@@ -273,7 +373,11 @@ def load_eeglab_tc_from_bytes(
273
  # ここは方針を選ぶ:平均 or 連結
274
  x_mean = x.mean(axis=0) # (C,T)
275
  x_tc = x_mean.T # (T,C)
276
- return x_tc.astype(np.float32), fs
 
 
 
 
277
 
278
  except Exception as e_ep:
279
  # 3) HDF5形式として読む(MATLAB v7.3)
@@ -285,7 +389,11 @@ def load_eeglab_tc_from_bytes(
285
  if 'streamlit' in sys.modules:
286
  debug = True
287
  x_tc, fs = _load_eeglab_hdf5(set_path, fdt_path=fdt_path, debug=debug)
288
- return x_tc, fs
 
 
 
 
289
 
290
  except Exception as e_hdf5:
291
  # すべて失敗した場合
 
42
  return a_stem == b_stem
43
 
44
 
45
+ def extract_electrode_positions_2d(set_path: str) -> np.ndarray:
46
+ """
47
+ EEGLABファイルから電極位置(2D)を抽出。
48
+
49
+ Returns:
50
+ pos: (C, 2) 電極の2D座標、取得できない場合はNone
51
+ """
52
+ try:
53
+ # MNEで読み込み
54
+ raw = mne.io.read_raw_eeglab(set_path, preload=False, verbose=False)
55
+ montage = raw.get_montage()
56
+
57
+ if montage is None:
58
+ return None
59
+
60
+ # 3D座標を取得
61
+ pos_3d = montage.get_positions()['ch_pos']
62
+
63
+ if not pos_3d:
64
+ return None
65
+
66
+ # チャンネル名順に並べ替え
67
+ ch_names = raw.ch_names
68
+ positions = []
69
+ for ch_name in ch_names:
70
+ if ch_name in pos_3d:
71
+ positions.append(pos_3d[ch_name])
72
+ else:
73
+ # 座標がないチャンネルは原点に配置
74
+ positions.append([0, 0, 0])
75
+
76
+ positions = np.array(positions)
77
+
78
+ # 3D -> 2D 投影(上から見た図)
79
+ # x, y座標を使用し、正規化
80
+ pos_2d = positions[:, :2]
81
+
82
+ # 正規化: 最大距離が1になるようにスケーリング
83
+ max_dist = np.max(np.sqrt(np.sum(pos_2d**2, axis=1)))
84
+ if max_dist > 0:
85
+ pos_2d = pos_2d / max_dist * 0.85 # 0.85倍で頭の輪郭内に収める
86
+
87
+ return pos_2d.astype(np.float32)
88
+
89
+ except Exception as e:
90
+ print(f"電極位置の抽出に失敗: {e}")
91
+ return None
92
+
93
+
94
  def _load_eeglab_hdf5(set_path: str, fdt_path: Optional[str] = None, debug: bool = False) -> Tuple[np.ndarray, float]:
95
  """
96
+ EEGLABファイルから電極位置(2D)を抽出。
97
+
98
+ Returns:
99
+ pos: (C, 2) 電極の2D座標、取得できない場合はNone
100
+ """
101
+ try:
102
+ # MNEで読み込み
103
+ raw = mne.io.read_raw_eeglab(set_path, preload=False, verbose=False)
104
+ montage = raw.get_montage()
105
+
106
+ if montage is None:
107
+ return None
108
+
109
+ # 3D座標を取得
110
+ pos_3d = montage.get_positions()['ch_pos']
111
+
112
+ if not pos_3d:
113
+ return None
114
+
115
+ # チャンネル名順に並べ替え
116
+ ch_names = raw.ch_names
117
+ positions = []
118
+ for ch_name in ch_names:
119
+ if ch_name in pos_3d:
120
+ positions.append(pos_3d[ch_name])
121
+ else:
122
+ # 座標がないチャンネルは原点に配置
123
+ positions.append([0, 0, 0])
124
+
125
+ positions = np.array(positions)
126
+
127
+ # 3D -> 2D 投影(上から見た図)
128
+ # x, y座標を使用し、正規化
129
+ pos_2d = positions[:, :2]
130
+
131
+ # 正規化: 最大距離が1になるようにスケーリング
132
+ max_dist = np.max(np.sqrt(np.sum(pos_2d**2, axis=1)))
133
+ if max_dist > 0:
134
+ pos_2d = pos_2d / max_dist * 0.85 # 0.85倍で頭の輪郭内に収める
135
+
136
+ return pos_2d.astype(np.float32)
137
+
138
+ except Exception as e:
139
+ print(f"電極位置の抽出に失敗: {e}")
140
+ return None
141
+ """
142
  Load EEGLAB .set file saved in MATLAB v7.3 (HDF5) format using h5py.
143
  Returns: (x_tc, fs) where x_tc is (T, C)
144
  """
 
322
  set_name: str,
323
  fdt_bytes: Optional[bytes] = None,
324
  fdt_name: Optional[str] = None,
325
+ ) -> Tuple[np.ndarray, float, Optional[np.ndarray]]:
326
  """
327
  Load EEGLAB .set (and optional .fdt) from bytes using MNE or h5py.
328
  Returns:
329
  x_tc: (T, C) float32
330
  fs: sampling rate (Hz)
331
+ electrode_pos: (C, 2) float32 or None - 電極の2D座標
332
 
333
  Notes:
334
  - 多くのEEGLABは .set が .fdt を参照するため、同じディレクトリに同名で置く必要があります。
 
357
  raw = mne.io.read_raw_eeglab(set_path, preload=True, verbose=False)
358
  fs = float(raw.info["sfreq"])
359
  x_tc = raw.get_data().T # (T,C)
360
+
361
+ # 電極位置を取得
362
+ electrode_pos = extract_electrode_positions_2d(set_path)
363
+
364
+ return x_tc.astype(np.float32), fs, electrode_pos
365
 
366
  except Exception as e_raw:
367
  # 2) Epochsとして読む(エポックデータ用)
 
373
  # ここは方針を選ぶ:平均 or 連結
374
  x_mean = x.mean(axis=0) # (C,T)
375
  x_tc = x_mean.T # (T,C)
376
+
377
+ # 電極位置を取得(epochsからも取得可能)
378
+ electrode_pos = extract_electrode_positions_2d(set_path)
379
+
380
+ return x_tc.astype(np.float32), fs, electrode_pos
381
 
382
  except Exception as e_ep:
383
  # 3) HDF5形式として読む(MATLAB v7.3)
 
389
  if 'streamlit' in sys.modules:
390
  debug = True
391
  x_tc, fs = _load_eeglab_hdf5(set_path, fdt_path=fdt_path, debug=debug)
392
+
393
+ # HDF5の場合は電極位置を取得できない(参照形式のため)
394
+ electrode_pos = None
395
+
396
+ return x_tc, fs, electrode_pos
397
 
398
  except Exception as e_hdf5:
399
  # すべて失敗した場合
src/streamlit_app.py CHANGED
@@ -107,14 +107,20 @@ def preprocess_all_eeglab(
107
  EEGLAB bytes -> load -> auto preprocess (bandpass + hilbert).
108
  fsは読み込んだデータのものを使う。
109
  """
110
- x_tc, fs = load_eeglab_tc_from_bytes(
111
  set_bytes=set_bytes,
112
  set_name=set_name,
113
  fdt_bytes=fdt_bytes,
114
  fdt_name=fdt_name,
115
  )
116
  cfg = PreprocessConfig(fs=float(fs), f_low=float(f_low), f_high=float(f_high))
117
- return preprocess_tc(x_tc, cfg)
 
 
 
 
 
 
118
 
119
 
120
  @st.cache_data(show_spinner=False)
@@ -351,11 +357,76 @@ def get_cluster_colors(clusters: np.ndarray) -> List[str]:
351
  return colors
352
 
353
 
354
- def make_network_figure(W: np.ndarray, thr: float, use_louvain: bool = True) -> tuple[go.Figure, int]:
355
- C = W.shape[0]
 
 
 
 
 
 
 
 
 
 
 
 
356
  angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
357
  xs = np.cos(angles)
358
  ys = np.sin(angles)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
 
360
  edges = threshold_edges(W, thr)
361
  fig = go.Figure()
@@ -380,6 +451,46 @@ def make_network_figure(W: np.ndarray, thr: float, use_louvain: bool = True) ->
380
  r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
381
  return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  # エッジを描画(重みに応じて色と太さを変える)
384
  for (i, j, w) in edges:
385
  # 正規化された重み (0-1)
@@ -422,6 +533,7 @@ def make_network_figure(W: np.ndarray, thr: float, use_louvain: bool = True) ->
422
  mode="markers+text",
423
  text=[f"{k}" for k in range(C)],
424
  textposition="bottom center",
 
425
  marker=dict(
426
  size=14,
427
  color=node_colors,
@@ -435,10 +547,10 @@ def make_network_figure(W: np.ndarray, thr: float, use_louvain: bool = True) ->
435
 
436
  fig.update_layout(
437
  title=f"Estimated Network (thr={thr:.3f}) edges={len(edges)}{title_suffix}",
438
- height=500,
439
  xaxis=dict(visible=False),
440
  yaxis=dict(visible=False),
441
- margin=dict(l=10, r=10, t=50, b=10),
442
  paper_bgcolor='rgba(0,0,0,0.9)',
443
  plot_bgcolor='rgba(0,0,0,0.9)',
444
  )
@@ -829,7 +941,7 @@ else:
829
  # 閾値スライダーとネットワーク図の表示
830
  wmax = float(np.max(W)) if np.isfinite(np.max(W)) else 1.0
831
 
832
- col_thr1, col_thr2 = st.columns([3, 1])
833
  with col_thr1:
834
  thr = st.slider(
835
  "閾値 (threshold) ※下げるほどエッジが増えます",
@@ -840,15 +952,35 @@ with col_thr1:
840
  )
841
  with col_thr2:
842
  use_louvain = st.checkbox(
843
- "Louvainクラスタリング",
844
  value=True,
845
  disabled=not LOUVAIN_AVAILABLE,
846
  help="ノードの色をコミュニティ検出結果で塗り分けます"
847
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
848
 
849
  net_col1, net_col2 = st.columns([2, 1])
850
  with net_col1:
851
- fig_net, edge_n = make_network_figure(W, float(thr), use_louvain=use_louvain)
 
 
 
 
 
 
852
  st.plotly_chart(fig_net)
853
 
854
  with net_col2:
 
107
  EEGLAB bytes -> load -> auto preprocess (bandpass + hilbert).
108
  fsは読み込んだデータのものを使う。
109
  """
110
+ x_tc, fs, electrode_pos = load_eeglab_tc_from_bytes(
111
  set_bytes=set_bytes,
112
  set_name=set_name,
113
  fdt_bytes=fdt_bytes,
114
  fdt_name=fdt_name,
115
  )
116
  cfg = PreprocessConfig(fs=float(fs), f_low=float(f_low), f_high=float(f_high))
117
+ result = preprocess_tc(x_tc, cfg)
118
+
119
+ # 電極位置を追加
120
+ if electrode_pos is not None:
121
+ result["electrode_pos"] = electrode_pos
122
+
123
+ return result
124
 
125
 
126
  @st.cache_data(show_spinner=False)
 
357
  return colors
358
 
359
 
360
+ def get_electrode_positions(prep: dict) -> np.ndarray:
361
+ """
362
+ 電極位置を取得する。
363
+
364
+ Returns:
365
+ pos: (C, 2) 電極の2D座標 (x, y)
366
+ 取得できない場合は円形配置を返す
367
+ """
368
+ # prepに電極位置が保存されているかチェック
369
+ if "electrode_pos" in prep:
370
+ return prep["electrode_pos"]
371
+
372
+ # デフォルト: 円形配置
373
+ C = prep["raw"].shape[1]
374
  angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
375
  xs = np.cos(angles)
376
  ys = np.sin(angles)
377
+ return np.column_stack([xs, ys])
378
+
379
+
380
+ def get_head_outline() -> dict:
381
+ """
382
+ 脳の輪郭(頭のアウトライン)を生成。
383
+
384
+ Returns:
385
+ outline: {'head': (x, y), 'nose': (x, y), 'ears': [(x_left, y_left), (x_right, y_right)]}
386
+ """
387
+ # 頭の円
388
+ theta = np.linspace(0, 2*np.pi, 100)
389
+ head_x = np.cos(theta)
390
+ head_y = np.sin(theta)
391
+
392
+ # 鼻(上部の三角形)
393
+ nose_x = np.array([0, -0.1, 0.1, 0])
394
+ nose_y = np.array([1.0, 1.15, 1.15, 1.0])
395
+
396
+ # 耳(左右の突起)
397
+ ear_theta = np.linspace(-np.pi/4, np.pi/4, 20)
398
+ ear_left_x = -1.0 + 0.08 * np.cos(ear_theta)
399
+ ear_left_y = 0.08 * np.sin(ear_theta)
400
+
401
+ ear_right_x = 1.0 - 0.08 * np.cos(ear_theta)
402
+ ear_right_y = 0.08 * np.sin(ear_theta)
403
+
404
+ return {
405
+ 'head': (head_x, head_y),
406
+ 'nose': (nose_x, nose_y),
407
+ 'ear_left': (ear_left_x, ear_left_y),
408
+ 'ear_right': (ear_right_x, ear_right_y),
409
+ }
410
+
411
+
412
+ def make_network_figure(
413
+ W: np.ndarray,
414
+ thr: float,
415
+ use_louvain: bool = True,
416
+ electrode_pos: np.ndarray = None,
417
+ show_head: bool = True,
418
+ ) -> tuple[go.Figure, int]:
419
+ C = W.shape[0]
420
+
421
+ # 電極位置を取得
422
+ if electrode_pos is None or electrode_pos.shape[0] != C:
423
+ # デフォルト: 円形配置
424
+ angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
425
+ xs = np.cos(angles)
426
+ ys = np.sin(angles)
427
+ else:
428
+ xs = electrode_pos[:, 0]
429
+ ys = electrode_pos[:, 1]
430
 
431
  edges = threshold_edges(W, thr)
432
  fig = go.Figure()
 
451
  r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
452
  return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'
453
 
454
+ # 脳の輪郭を描画
455
+ if show_head:
456
+ outline = get_head_outline()
457
+
458
+ # 頭の円
459
+ fig.add_trace(go.Scatter(
460
+ x=outline['head'][0], y=outline['head'][1],
461
+ mode='lines',
462
+ line=dict(color='rgba(150,150,150,0.5)', width=2),
463
+ showlegend=False,
464
+ hoverinfo='skip',
465
+ ))
466
+
467
+ # 鼻
468
+ fig.add_trace(go.Scatter(
469
+ x=outline['nose'][0], y=outline['nose'][1],
470
+ mode='lines',
471
+ line=dict(color='rgba(150,150,150,0.5)', width=2),
472
+ showlegend=False,
473
+ hoverinfo='skip',
474
+ ))
475
+
476
+ # 左耳
477
+ fig.add_trace(go.Scatter(
478
+ x=outline['ear_left'][0], y=outline['ear_left'][1],
479
+ mode='lines',
480
+ line=dict(color='rgba(150,150,150,0.5)', width=2),
481
+ showlegend=False,
482
+ hoverinfo='skip',
483
+ ))
484
+
485
+ # 右耳
486
+ fig.add_trace(go.Scatter(
487
+ x=outline['ear_right'][0], y=outline['ear_right'][1],
488
+ mode='lines',
489
+ line=dict(color='rgba(150,150,150,0.5)', width=2),
490
+ showlegend=False,
491
+ hoverinfo='skip',
492
+ ))
493
+
494
  # エッジを描画(重みに応じて色と太さを変える)
495
  for (i, j, w) in edges:
496
  # 正規化された重み (0-1)
 
533
  mode="markers+text",
534
  text=[f"{k}" for k in range(C)],
535
  textposition="bottom center",
536
+ textfont=dict(size=8),
537
  marker=dict(
538
  size=14,
539
  color=node_colors,
 
547
 
548
  fig.update_layout(
549
  title=f"Estimated Network (thr={thr:.3f}) edges={len(edges)}{title_suffix}",
550
+ height=600,
551
  xaxis=dict(visible=False),
552
  yaxis=dict(visible=False),
553
+ margin=dict(l=10, r=10, t=50, b=50),
554
  paper_bgcolor='rgba(0,0,0,0.9)',
555
  plot_bgcolor='rgba(0,0,0,0.9)',
556
  )
 
941
  # 閾値スライダーとネットワーク図の表示
942
  wmax = float(np.max(W)) if np.isfinite(np.max(W)) else 1.0
943
 
944
+ col_thr1, col_thr2, col_thr3 = st.columns([2, 1, 1])
945
  with col_thr1:
946
  thr = st.slider(
947
  "閾値 (threshold) ※下げるほどエッジが増えます",
 
952
  )
953
  with col_thr2:
954
  use_louvain = st.checkbox(
955
+ "Louvainクラスタ",
956
  value=True,
957
  disabled=not LOUVAIN_AVAILABLE,
958
  help="ノードの色をコミュニティ検出結果で塗り分けます"
959
  )
960
+ with col_thr3:
961
+ show_head = st.checkbox(
962
+ "脳の輪郭を表示",
963
+ value=True,
964
+ help="頭部のアウトラインを表示します"
965
+ )
966
+
967
+ # 電極位置を取得
968
+ electrode_pos = prep.get("electrode_pos", None)
969
+
970
+ if electrode_pos is not None:
971
+ st.info(f"✓ 電極位置を使用してネットワークを配置 ({electrode_pos.shape[0]} channels)")
972
+ else:
973
+ st.info("ℹ️ 電極位置が取得できなかったため、円形配置を使用します")
974
 
975
  net_col1, net_col2 = st.columns([2, 1])
976
  with net_col1:
977
+ fig_net, edge_n = make_network_figure(
978
+ W,
979
+ float(thr),
980
+ use_louvain=use_louvain,
981
+ electrode_pos=electrode_pos,
982
+ show_head=show_head,
983
+ )
984
  st.plotly_chart(fig_net)
985
 
986
  with net_col2: