File size: 45,802 Bytes
b11ec91
 
 
 
3705605
 
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
 
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
b11ec91
 
 
 
 
 
5ec4552
 
 
e648c90
 
 
 
5ec4552
 
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
b11ec91
 
 
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
 
 
 
 
e648c90
 
b11ec91
 
 
 
 
 
 
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
 
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
 
 
 
e648c90
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
b11ec91
e648c90
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ec4552
 
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
 
 
5ec4552
 
e648c90
 
 
 
 
 
5ec4552
e648c90
5ec4552
e648c90
 
 
 
5ec4552
e648c90
 
5ec4552
e648c90
 
 
 
 
 
 
 
 
 
5ec4552
e648c90
 
 
 
 
 
5ec4552
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ec4552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
 
e648c90
b11ec91
 
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b11ec91
e648c90
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ec4552
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
5ec4552
b11ec91
 
5ec4552
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
 
 
 
 
 
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
 
 
 
 
 
 
 
 
b11ec91
 
 
 
e648c90
 
b11ec91
e648c90
 
b11ec91
e648c90
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
b11ec91
e648c90
 
b11ec91
 
e648c90
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
b11ec91
e648c90
 
b11ec91
 
e648c90
 
 
 
 
b11ec91
 
e648c90
b11ec91
 
 
 
 
e648c90
 
 
 
b11ec91
 
 
 
 
 
 
 
 
 
 
e648c90
b11ec91
 
 
 
e648c90
 
 
b11ec91
 
 
e648c90
 
 
 
 
 
 
 
 
 
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e648c90
b11ec91
 
 
 
 
e648c90
b11ec91
 
 
 
5ec4552
b11ec91
 
 
 
5ec4552
 
 
e648c90
 
 
 
 
 
 
5ec4552
 
 
 
 
b11ec91
e648c90
 
 
 
 
 
b11ec91
 
5ec4552
 
 
 
 
 
b11ec91
e648c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3705605
b11ec91
 
e648c90
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
import io
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import streamlit as st
import plotly.graph_objects as go
import mne
from scipy.signal import hilbert

# Optional dependency: Louvain clustering needs python-louvain and networkx.
# Only availability is recorded here; the user-facing warning is emitted
# after st.set_page_config() further down.
try:
    import community as community_louvain
    import networkx as nx
    LOUVAIN_AVAILABLE = True
except ImportError:
    LOUVAIN_AVAILABLE = False

from loader import (
    pick_set_fdt,
    load_eeglab_tc_from_bytes,
    load_mat_candidates,
)

import metrics

# Streamlit documents set_page_config as the first Streamlit command of a
# script, so it must run before the optional-dependency warning below.
st.set_page_config(page_title="EEG Viewer + Network Estimation", layout="wide")

if not LOUVAIN_AVAILABLE:
    st.warning("⚠️ Louvainクラスタリングを使用するには `pip install python-louvain networkx` を実行してください。")


# ============================================================
# Preprocess config
# ============================================================
@dataclass(frozen=True)
class PreprocessConfig:
    """Immutable settings for the band-pass preprocessing step."""

    # Sampling frequency; forwarded to MNE as `sfreq` by bandpass_tc.
    fs: float
    # Lower band edge; forwarded to the MNE filter as `l_freq`.
    f_low: float
    # Upper band edge; forwarded to the MNE filter as `h_freq`.
    f_high: float


# ============================================================
# Helpers
# ============================================================
def ensure_tc(x: np.ndarray) -> np.ndarray:
    """Coerce the input to a 2-D (T, C) array.

    A 1-D input becomes a single column. A 2-D input is transposed only
    when the heuristic judges it to be channel-major: at most 256 rows and
    more columns than rows.

    Raises:
        ValueError: If the input is neither 1-D nor 2-D.
    """
    arr = np.asarray(x)
    if arr.ndim == 1:
        return arr[:, None]
    if arr.ndim == 2:
        n_rows, n_cols = arr.shape
        looks_channel_major = n_rows <= 256 and n_cols > n_rows
        return arr.T if looks_channel_major else arr
    raise ValueError(f"2次元配列のみ対応です: shape={arr.shape}")

def _quad_bezier_points(p0, p1, c, n=20):
    """Sample a quadratic Bezier curve at ``n`` evenly spaced parameters.

    Args:
        p0: Start point.
        p1: End point.
        c: Control point.
        n: Number of samples over t in [0, 1].

    Returns:
        Array with one row per sample (shape (n, 2) for 2-D points).
    """
    t = np.linspace(0, 1, n)[:, None]  # column vector so points broadcast per row
    s = 1 - t
    return s**2 * p0 + 2 * s * t * c + t**2 * p1

def _quad_bezier_point_and_tangent(p0, p1, c, t):
    """Evaluate a quadratic Bezier curve and its derivative at parameter ``t``.

    Uses B(t) = (1-t)^2 p0 + 2(1-t)t c + t^2 p1 and
    B'(t) = 2(1-t)(c - p0) + 2t(p1 - c).

    Returns:
        Tuple ``(point, tangent)``.
    """
    s = 1 - t
    point = s**2 * p0 + 2 * s * t * c + t**2 * p1
    tangent = 2 * s * (c - p0) + 2 * t * (p1 - c)
    return point, tangent

# ============================================================
# Signal processing
# ============================================================
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Band-pass filter every channel through an MNE RawArray.

    Args:
        x_tc: Signal array laid out as (T, C).
        cfg: Supplies the sampling rate (``fs``) and band edges
            (``f_low``, ``f_high``).

    Returns:
        Filtered signal as a float32 (T, C) array.
    """
    n_channels = x_tc.shape[1]
    channel_names = [f"ch{idx}" for idx in range(n_channels)]
    info = mne.create_info(
        ch_names=channel_names,
        sfreq=float(cfg.fs),
        ch_types="eeg",
    )
    # RawArray is given channels-first data, hence the transpose to (C, T).
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)
    # Filter a copy so the unfiltered RawArray is left untouched.
    filtered = raw.copy().filter(
        l_freq=cfg.f_low,
        h_freq=cfg.f_high,
        verbose=False,
    )
    data_ct = filtered.get_data()
    return data_ct.T.astype(np.float32)


def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the per-channel Hilbert envelope of a (T, C) signal.

    The analytic signal is computed along the time axis (axis 0) with SciPy
    and its magnitude is returned as float32, same shape as the input.
    """
    magnitude = np.abs(hilbert(x_tc, axis=0))
    return magnitude.astype(np.float32)


def hilbert_phase_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the per-channel instantaneous phase of a (T, C) signal.

    The analytic signal is computed along the time axis (axis 0) with SciPy
    and its angle, in radians, is returned as float32.
    """
    angle_rad = np.angle(hilbert(x_tc, axis=0))
    return angle_rad.astype(np.float32)


def preprocess_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> dict:
    """Run the full preprocessing chain on a raw signal.

    The input is coerced to (T, C) float32, band-pass filtered, and then
    turned into an envelope and a phase via the Hilbert transform.

    Args:
        x_tc: Raw signal; any layout accepted by ensure_tc.
        cfg: Sampling rate and band edges.

    Returns:
        Dict with keys "fs", "raw", "filtered", "envelope", "amplitude"
        (the same array as "envelope") and "phase".
    """
    x_tc = ensure_tc(x_tc).astype(np.float32)
    x_filt = bandpass_tc(x_tc, cfg)

    # Envelope and phase come from the same analytic signal, so the Hilbert
    # transform is computed once and reused instead of once per quantity.
    analytic = hilbert(x_filt, axis=0)
    env = np.abs(analytic).astype(np.float32)
    phase = np.angle(analytic).astype(np.float32)

    return {
        "fs": float(cfg.fs),
        "raw": x_tc,
        "filtered": x_filt,
        "envelope": env,
        "amplitude": env,  # alias of "envelope"
        "phase": phase,
    }


@st.cache_data(show_spinner=False)
def preprocess_all_eeglab(
    set_bytes: bytes,
    fdt_bytes: bytes,
    set_name: str,
    fdt_name: str,
    f_low: float,
    f_high: float,
) -> dict:
    """Load an EEGLAB recording from bytes and preprocess it (cached).

    The sampling rate is the one returned by the loader; only the band
    edges come from the caller.

    Returns:
        The dict produced by preprocess_tc, extended with "electrode_pos"
        and/or "electrode_pos_3d" when the loader returned them.
    """
    loaded = load_eeglab_tc_from_bytes(
        set_bytes=set_bytes,
        set_name=set_name,
        fdt_bytes=fdt_bytes,
        fdt_name=fdt_name,
    )
    x_tc, fs, electrode_pos_2d, electrode_pos_3d = loaded

    result = preprocess_tc(
        x_tc,
        PreprocessConfig(fs=float(fs), f_low=float(f_low), f_high=float(f_high)),
    )

    # Attach electrode coordinates only when the loader provided them.
    optional_entries = (
        ("electrode_pos", electrode_pos_2d),
        ("electrode_pos_3d", electrode_pos_3d),
    )
    for key, value in optional_entries:
        if value is not None:
            result[key] = value

    return result


@st.cache_data(show_spinner=False)
def load_mat_candidates_cached(mat_bytes: bytes) -> dict:
    """Cached wrapper around load_mat_candidates.

    Caching avoids re-reading the MAT file on every UI interaction.
    """
    candidates = load_mat_candidates(mat_bytes)
    return candidates


# ============================================================
# Viewer
# ============================================================
def window_slice(X_tc: np.ndarray, start_idx: int, end_idx: int, decim: int) -> np.ndarray:
    """Return rows ``start_idx:end_idx`` of a (T, C) array, keeping every ``decim``-th.

    The start is clamped into [0, T-1], the end is forced to be at least one
    past the start and at most T, and the step is forced to be at least 1.
    """
    n_samples = X_tc.shape[0]
    first = max(0, min(start_idx, n_samples - 1))
    stop = max(first + 1, min(end_idx, n_samples))
    step = max(1, int(decim))
    return X_tc[first:stop:step, :]


def make_timeseries_figure(
    X_tc: np.ndarray,
    selected_channels: List[int],
    fs: float,
    start_sec: float,
    win_sec: float,
    decim: int,
    offset_mode: bool,
    show_rangeslider: bool,
    signal_type: str = "filtered",
) -> go.Figure:
    """Plot a time window of the selected channels as Plotly line traces.

    Args:
        X_tc: Signal array laid out as (T, C).
        selected_channels: Column indices to draw; an empty list yields an
            empty placeholder figure.
        fs: Sampling rate used to convert between seconds and samples.
        start_sec: Window start in seconds.
        win_sec: Window length in seconds.
        decim: Keep every ``decim``-th sample (values below 1 act as 1).
        offset_mode: Stack channels vertically; ignored for phase data or
            when only one channel is selected.
        show_rangeslider: Show the x-axis range slider (the figure is made
            taller to leave room for it).
        signal_type: Label used in the title; "phase" switches the y-axis
            to a fixed radian range.

    Returns:
        The assembled figure.
    """
    # Normalise the step and clamp the start the same way window_slice does,
    # so the time axis describes the samples that are actually drawn even
    # when the requested window lies (partly) outside the data.
    decim = max(1, int(decim))
    start_idx = int(round(start_sec * fs))
    end_idx = int(round((start_sec + win_sec) * fs))
    start_idx = max(0, min(start_idx, X_tc.shape[0] - 1))

    Xw = window_slice(X_tc, start_idx, end_idx, decim)
    t = (np.arange(Xw.shape[0]) * decim + start_idx) / fs

    fig = go.Figure()

    if not selected_channels:
        fig.update_layout(
            title="Timeseries (no channel selected)",
            height=450,
            xaxis_title="time (s)",
            yaxis_title="amplitude",
        )
        return fig

    # Phase data is never offset and gets its own axis label and range.
    is_phase = signal_type == "phase"

    if offset_mode and len(selected_channels) > 1 and not is_phase:
        # Spacing between stacked traces: 5x the median per-channel std,
        # falling back to 1.0 when that median is zero or not finite.
        per_ch_std = np.std(Xw[:, selected_channels], axis=0)
        median_std = float(np.median(per_ch_std))
        base = median_std if np.isfinite(median_std) and median_std > 0 else 1.0
        offset = 5.0 * base

        for k, ch in enumerate(selected_channels):
            fig.add_trace(
                go.Scatter(
                    x=t,
                    y=Xw[:, ch] + k * offset,
                    mode="lines",
                    name=f"ch{ch}",
                    line=dict(width=1),
                )
            )
        ylab = "amplitude (offset)"
    else:
        for ch in selected_channels:
            fig.add_trace(
                go.Scatter(x=t, y=Xw[:, ch], mode="lines", name=f"ch{ch}", line=dict(width=1))
            )
        ylab = "phase (rad)" if is_phase else "amplitude"

    # Leave extra room at the bottom when the range slider is shown.
    plot_height = 550 if show_rangeslider else 450
    bottom_margin = 150 if show_rangeslider else 80

    title_text = f"Timeseries: {signal_type}  (window={win_sec:.2f}s, start={start_sec:.2f}s, decim={decim})"

    fig.update_layout(
        title=title_text,
        height=plot_height,
        xaxis_title="time (s)",
        yaxis_title=ylab,
        legend=dict(orientation="h"),
        margin=dict(l=60, r=20, t=80, b=bottom_margin),
    )

    # Phase plots use a fixed y-range of [-pi, pi] padded by 0.5 rad.
    if is_phase:
        fig.update_yaxes(range=[-np.pi - 0.5, np.pi + 0.5])

    if show_rangeslider:
        fig.update_xaxes(
            rangeslider=dict(
                visible=True,
                thickness=0.05,
            )
        )
    else:
        fig.update_xaxes(rangeslider=dict(visible=False))

    return fig


# ============================================================
# Network (multiple methods) + export
# ============================================================
def estimate_network_envelope_corr(X_tc: np.ndarray) -> np.ndarray:
    """Absolute Pearson correlation between channel envelopes.

    Args:
        X_tc: Envelope (amplitude) data laid out as (T, C).

    Returns:
        float32 (C, C) matrix of absolute correlation coefficients with a
        zero diagonal; non-finite entries (e.g. from constant channels) are
        replaced by 0.
    """
    X = X_tc - X_tc.mean(axis=0, keepdims=True)
    corr = np.corrcoef(X, rowvar=False)
    # With a single channel np.corrcoef returns a 0-d value, which
    # np.fill_diagonal rejects; promote it so the result is always (C, C).
    W = np.atleast_2d(np.abs(corr))
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

def estimate_network_phase_corr(X_tc: np.ndarray) -> np.ndarray:
    """Pairwise circular correlation between channel phases.

    Each unordered channel pair is scored with metrics.circular_correlation
    and the value is written symmetrically.
    NOTE(review): the original inline comment names this the
    Jammalamadaka-Sengupta circular correlation; confirm against the
    metrics module.

    Args:
        X_tc: Phase data in radians, laid out as (T, C).

    Returns:
        float32 (C, C) symmetric matrix with a zero diagonal; non-finite
        entries are replaced by 0.
    """
    _, n_channels = X_tc.shape
    W = np.zeros((n_channels, n_channels), dtype=np.float32)

    for a in range(n_channels):
        phase_a = X_tc[:, a]
        for b in range(a + 1, n_channels):
            # One evaluation fills both halves of the matrix.
            W[a, b] = W[b, a] = metrics.circular_correlation(phase_a, X_tc[:, b])

    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

def estimate_network_phase_PLV(X_tc: np.ndarray, progress) -> np.ndarray:
    """Pairwise phase-locking value (PLV) between channels.

    For channels i and j the PLV is |mean_t exp(1j * (theta_i - theta_j))|.

    Args:
        X_tc: Phase data in radians, laid out as (T, C).
        progress: Object whose ``progress(fraction)`` method is called after
            every channel pair with the completed fraction.

    Returns:
        float32 (C, C) symmetric PLV matrix with a zero diagonal; non-finite
        entries are replaced by 0.
    """
    _, n_channels = X_tc.shape
    W = np.zeros((n_channels, n_channels), dtype=np.float32)

    n_pairs = int(n_channels * (n_channels - 1) / 2)
    done = 0
    for a in range(n_channels):
        for b in range(a + 1, n_channels):
            delta = X_tc[:, a] - X_tc[:, b]  # per-sample phase difference
            W[a, b] = W[b, a] = np.abs(np.mean(np.exp(1j * delta)))
            done += 1
            progress.progress(done / n_pairs)

    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def estimate_network_pac_tort(X_tc1, X_tc2, progress):
    """Directed phase-amplitude coupling matrix based on a modulation index.

    W[i, j] is metrics.modulation_index applied to the phase of channel i
    and the envelope of channel j.
    NOTE(review): the original comment attributes the index to
    Tort et al. (2010); confirm against the metrics module.

    Args:
        X_tc1: Phase data in radians, laid out as (T, C).
        X_tc2: Envelope data with the same shape as ``X_tc1``.
        progress: Object whose ``progress(fraction)`` method is called after
            every evaluated pair.

    Returns:
        float32 (C, C) matrix with a zero diagonal; non-finite entries are
        replaced by 0.
    """
    assert X_tc1.shape == X_tc2.shape
    _, C = X_tc1.shape
    W = np.zeros((C, C), dtype=np.float32)

    # The diagonal is skipped, so C*(C-1) ordered pairs are evaluated. Using
    # that count (rather than C*C) lets the progress bar finish at 1.0.
    n_pairs = C * (C - 1)
    done = 0
    for i in range(C):
        for j in range(C):
            if i == j:
                continue
            W[i, j] = metrics.modulation_index(X_tc1[:, i], X_tc2[:, j])
            done += 1
            progress.progress(done / n_pairs)

    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

def estimate_network_pac_chatterjee(X_tc1, X_tc2, progress):
    """Directed phase-amplitude coupling matrix based on Chatterjee correlation.

    W[i, j] is metrics.chatterjee_phase_to_amp applied to the phase of
    channel i and the envelope of channel j.

    Args:
        X_tc1: Phase data in radians, laid out as (T, C).
        X_tc2: Envelope data with the same shape as ``X_tc1``.
        progress: Object whose ``progress(fraction)`` method is called after
            every evaluated pair.

    Returns:
        float32 (C, C) matrix with a zero diagonal; non-finite entries are
        replaced by 0.
    """
    assert X_tc1.shape == X_tc2.shape
    _, C = X_tc1.shape
    W = np.zeros((C, C), dtype=np.float32)

    # The diagonal is skipped, so C*(C-1) ordered pairs are evaluated. Using
    # that count (rather than C*C) lets the progress bar finish at 1.0.
    n_pairs = C * (C - 1)
    done = 0
    for i in range(C):
        for j in range(C):
            if i == j:
                continue
            W[i, j] = metrics.chatterjee_phase_to_amp(X_tc1[:, i], X_tc2[:, j])
            done += 1
            progress.progress(done / n_pairs)

    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def estimate_network_dummy(X_tc: np.ndarray) -> np.ndarray:
    """Placeholder estimator: absolute Pearson correlation between columns.

    Kept for backward compatibility.

    Args:
        X_tc: Data laid out as (T, C).

    Returns:
        float32 (C, C) matrix of absolute correlation coefficients with a
        zero diagonal; non-finite entries are replaced by 0.
    """
    X = X_tc - X_tc.mean(axis=0, keepdims=True)
    corr = np.corrcoef(X, rowvar=False)
    # With a single channel np.corrcoef returns a 0-d value, which
    # np.fill_diagonal rejects; promote it so the result is always (C, C).
    W = np.atleast_2d(np.abs(corr))
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def threshold_edges(
    W: np.ndarray,
    thr: float,
) -> List[Tuple[int, int, float]]:
    """Extract the edges whose weight is at least ``thr``.

    - Symmetric W: treated as undirected; only pairs with i < j are returned.
    - Asymmetric W: treated as directed; every off-diagonal i -> j entry is
      considered.

    Args:
        W: (C, C) weight matrix.
        thr: Inclusive lower bound on the edge weight.

    Returns:
        List of ``(i, j, w)`` tuples sorted by descending weight; ties keep
        row-major order.
    """
    is_symmetric = np.allclose(W, W.T, atol=1e-12, rtol=0)

    # Compare in float64 so the test matches a per-element
    # ``float(W[i, j]) >= thr`` regardless of W's dtype.
    weights = np.asarray(W, dtype=np.float64)
    keep = weights >= thr
    if is_symmetric:
        keep = np.triu(keep, k=1)      # upper triangle only, no diagonal
    else:
        np.fill_diagonal(keep, False)  # drop self-loops

    # np.nonzero yields indices in row-major order, the same order nested
    # i/j loops would visit them; the stable sort preserves it among ties.
    rows, cols = np.nonzero(keep)
    edges: List[Tuple[int, int, float]] = list(
        zip(rows.tolist(), cols.tolist(), weights[rows, cols].tolist())
    )
    edges.sort(key=lambda edge: edge[2], reverse=True)
    return edges



def adjacency_at_threshold(W: np.ndarray, thr: float, weighted: bool) -> np.ndarray:
    """Build an adjacency matrix from W at the given threshold.

    Args:
        W: Weight matrix; it is not modified.
        thr: Entries below this value are dropped.
        weighted: If true, surviving entries keep their weight; otherwise
            the result is an integer 0/1 matrix.

    Returns:
        Matrix of the same shape as W with a zero diagonal.
    """
    if not weighted:
        binary = (W >= thr).astype(int)
        np.fill_diagonal(binary, 0)
        return binary

    kept = W.copy()
    kept[kept < thr] = 0.0
    np.fill_diagonal(kept, 0.0)
    return kept


def compute_louvain_clusters(W: np.ndarray, thr: float) -> np.ndarray:
    """Assign each node to a community with the Louvain method.

    An undirected graph is built with one node per row of W. A pair (i, j)
    is connected when W[i, j] >= thr, weighted by the larger of W[i, j] and
    W[j, i].

    Args:
        W: (C, C) weight matrix.
        thr: Inclusive threshold; entries below it add no edge.

    Returns:
        Array of C cluster ids. When the optional Louvain packages are
        unavailable, every node is placed in cluster 0.
    """
    n_nodes = W.shape[0]
    if not LOUVAIN_AVAILABLE:
        return np.zeros(n_nodes, dtype=int)

    graph = nx.Graph()
    graph.add_nodes_from(range(n_nodes))

    # Visit every ordered pair; re-adding an existing edge just rewrites
    # the same symmetric weight.
    for src, dst in np.ndindex(n_nodes, n_nodes):
        if W[src, dst] >= thr:
            graph.add_edge(src, dst, weight=max(W[src, dst], W[dst, src]))

    membership = community_louvain.best_partition(graph, weight='weight')
    return np.array([membership[node] for node in range(n_nodes)])


def get_cluster_colors(clusters: np.ndarray) -> List[str]:
    """Map cluster ids to CSS ``rgb(...)`` colour strings.

    The hue of each node is its cluster id divided by the number of distinct
    ids, at fixed saturation 0.8 and value 0.95.

    Args:
        clusters: Array of C cluster ids.

    Returns:
        List of C colour strings, one per node.
    """
    import colorsys

    n_clusters = len(np.unique(clusters))
    divisor = max(n_clusters, 1)  # avoids dividing by zero for empty input

    def _to_css(cluster_id):
        rgb = colorsys.hsv_to_rgb(cluster_id / divisor, 0.8, 0.95)
        red, green, blue = (int(255 * channel) for channel in rgb)
        return f'rgb({red}, {green}, {blue})'

    return [_to_css(cluster_id) for cluster_id in clusters]


def get_electrode_positions(prep: dict) -> np.ndarray:
    """Return 2-D electrode coordinates for plotting.

    Args:
        prep: Preprocessing result. Its "electrode_pos" entry is returned
            as-is when present; otherwise "raw" supplies the channel count.

    Returns:
        The stored positions, or a (C, 2) array of points evenly spaced on
        the unit circle when none are stored.
    """
    if "electrode_pos" not in prep:
        # Fallback layout: channels evenly spaced around the unit circle.
        n_channels = prep["raw"].shape[1]
        theta = np.linspace(0, 2 * np.pi, n_channels, endpoint=False)
        return np.column_stack([np.cos(theta), np.sin(theta)])

    return prep["electrode_pos"]

def make_network_figure_3d(
    W: np.ndarray,
    thr: float,
    electrode_pos_3d: np.ndarray,
    use_louvain: bool = True,
) -> go.Figure:
    """Build a rotatable 3-D Plotly figure of the thresholded network.

    Args:
        W: (C, C) weight matrix.
        thr: Threshold passed to threshold_edges and, when clustering is
            enabled, to compute_louvain_clusters.
        electrode_pos_3d: Array whose first three columns are used as the
            x, y and z coordinates of the nodes.
        use_louvain: Colour nodes by Louvain cluster when the optional
            packages are available; otherwise all nodes are gold.

    Returns:
        Figure with one Scatter3d trace per edge plus one trace for the nodes.
    """
    import colorsys

    n_nodes = W.shape[0]
    xs, ys, zs = (electrode_pos_3d[:, axis] for axis in range(3))

    edges = threshold_edges(W, thr)
    fig = go.Figure()

    # Weight span used to normalise edge colour and width.
    if edges:
        weights = [weight for _, _, weight in edges]
        min_w, max_w = min(weights), max(weights)
        weight_range = max_w - min_w if max_w > min_w else 1.0
    else:
        min_w, weight_range = 0, 1.0

    # One trace per edge: hue runs from 0.67 (weakest) down to 0 (strongest)
    # and the line gets thicker with weight.
    for src, dst, weight in edges:
        norm_w = (weight - min_w) / weight_range if weight_range > 0 else 0.5
        r, g, b = colorsys.hsv_to_rgb((1.0 - norm_w) * 0.67, 0.9, 0.95)
        edge_color = f'rgb({int(255*r)}, {int(255*g)}, {int(255*b)})'

        fig.add_trace(go.Scatter3d(
            x=[xs[src], xs[dst], None],
            y=[ys[src], ys[dst], None],
            z=[zs[src], zs[dst], None],
            mode='lines',
            line=dict(color=edge_color, width=1 + 4 * norm_w),
            hoverinfo='skip',
            showlegend=False,
        ))

    # Node colouring: Louvain clusters when requested and available.
    if use_louvain and LOUVAIN_AVAILABLE:
        clusters = compute_louvain_clusters(W, thr)
        node_colors = get_cluster_colors(clusters)
        title_suffix = f"  |  Louvain clusters: {len(np.unique(clusters))}"
    else:
        clusters = np.zeros(n_nodes, dtype=int)
        node_colors = ['#FFD700'] * n_nodes
        title_suffix = ""

    node_ids = range(n_nodes)
    fig.add_trace(go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='markers+text',
        text=[f"{k}" for k in node_ids],
        textposition='top center',
        textfont=dict(size=8),
        marker=dict(
            size=8,
            color=node_colors,
            line=dict(color='white', width=1),
        ),
        hoverinfo='text',
        hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in node_ids],
        showlegend=False,
    ))

    hidden_axis = dict(visible=False)
    fig.update_layout(
        title=f"3D Network (thr={thr:.3f})  edges={len(edges)}{title_suffix}",
        height=700,
        scene=dict(
            xaxis=hidden_axis,
            yaxis=hidden_axis,
            zaxis=hidden_axis,
            bgcolor='rgba(0,0,0,0.9)',
        ),
        paper_bgcolor='rgba(0,0,0,0.9)',
        margin=dict(l=0, r=0, t=50, b=0),
    )

    return fig


def make_network_figure(
    W: np.ndarray, 
    thr: float, 
    use_louvain: bool = True,
    electrode_pos: np.ndarray = None,
) -> tuple[go.Figure, int]:
    """Build a 2-D Plotly figure of the thresholded network.

    Symmetric matrices are drawn with straight, undirected edges. Asymmetric
    matrices are drawn as directed edges: curved lines with a triangle
    marker near the target end.

    Args:
        W: (C, C) weight matrix.
        thr: Threshold passed to threshold_edges and, when clustering is
            enabled, to compute_louvain_clusters.
        use_louvain: Colour nodes by Louvain cluster when the optional
            packages are available; otherwise all nodes are gold.
        electrode_pos: Optional array whose first two columns are the node
            x/y coordinates. When it is None or its row count differs from
            C, nodes are placed evenly on the unit circle.

    Returns:
        Tuple ``(figure, number_of_edges)``.
    """
    C = W.shape[0]
    
    # Resolve node coordinates.
    if electrode_pos is None or electrode_pos.shape[0] != C:
        # Default: evenly spaced points on the unit circle.
        angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
        xs = np.cos(angles)
        ys = np.sin(angles)
    else:
        xs = electrode_pos[:, 0]
        ys = electrode_pos[:, 1]

    edges = threshold_edges(W, thr)
    fig = go.Figure()

    # Weight span of the kept edges (used to scale colour and line width).
    if edges:
        weights = [w for _, _, w in edges]
        min_w = min(weights)
        max_w = max(weights)
        weight_range = max_w - min_w if max_w > min_w else 1.0
    else:
        min_w = 0
        max_w = 1
        weight_range = 1.0

    # Rainbow colour map (0 = blue -> 0.5 = green/yellow -> 1 = red).
    def get_rainbow_color(norm_val):
        """Return an rgba colour string for a normalised value in 0-1."""
        import colorsys
        # HSV hue runs from 240 degrees (blue) down to 0 degrees (red).
        hue = (1.0 - norm_val) * 0.67  # 0.67 is roughly 240/360 (blue)
        r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
        return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'

    # Draw edges; colour and width follow the normalised weight.
    
    # Directed case only: curve each edge and place an arrow (triangle
    # marker) towards its target end.
    is_symmetric = np.allclose(W, W.T, atol=1e-12, rtol=0)
    if (not is_symmetric):
        curve_strength = 0.1   # how strongly edges bend (tunable)
        node_radius = 0.08      # how far short of the node centres the curve starts/ends (tunable)
        bezier_n = 18           # number of samples per curve (more = smoother)
        t_arrow = 0.90          # curve parameter (0-1) at which the arrow is placed
        for (i, j, w) in edges:
            norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
            color = get_rainbow_color(norm_w)
            line_width = 0.5 + 3.5 * norm_w

            p0 = np.array([xs[i], ys[i]], dtype=float)
            p1 = np.array([xs[j], ys[j]], dtype=float)

            v = p1 - p0
            dist = np.hypot(v[0], v[1])
            # Skip edges between (nearly) coincident nodes; no direction exists.
            if dist < 1e-9:
                continue
            u = v / dist

            # Shorten both ends so the curve does not overlap the node markers.
            p0s = p0 + u * node_radius
            p1s = p1 - u * node_radius

            # Unit normal to the edge direction.
            n = np.array([-u[1], u[0]])

            # Every directed edge bends to the same side of its own
            # direction, so i -> j and j -> i bow away from each other.
            sign = 1.0  # fixed side; an index-based alternation (i < j) was left disabled

            # Control point: midpoint pushed along the normal.
            mid = 0.5 * (p0s + p1s)
            c = mid + sign * curve_strength * dist * n

            # Sampled curve points.
            pts = _quad_bezier_points(p0s, p1s, c, n=bezier_n)

            fig.add_trace(go.Scatter(
                x=pts[:, 0],
                y=pts[:, 1],
                mode="lines",
                hoverinfo="text",
                hovertext=f"ch{i} → ch{j}<br>weight: {w:.4f}",
                line=dict(width=line_width, color=color),
                showlegend=False,
            ))

            # Arrow position and direction from the curve tangent.
            pt, tan = _quad_bezier_point_and_tangent(p0s, p1s, c, t_arrow)

            # Fall back to the chord direction if the tangent is ~zero.
            tx, ty = float(tan[0]), float(tan[1])
            if tx*tx + ty*ty < 1e-18:
                tx, ty = float(p1s[0] - p0s[0]), float(p1s[1] - p0s[1])

            theta = np.degrees(np.arctan2(ty, tx))  # tangent angle measured from +x
            ANGLE_OFFSET = -90.0  # correction that aligns the upward-pointing triangle with the tangent
            ang = (theta + ANGLE_OFFSET) % 360

            fig.add_trace(go.Scatter(
                x=[pt[0]],
                y=[pt[1]],
                mode="markers",
                hoverinfo="skip",
                marker=dict(
                    symbol="triangle-up",
                    size=10,
                    angle=-ang,
                    angleref="up",
                    color=color,
                    line=dict(width=0),
                ),
                showlegend=False,
            ))
    else:
        for (i, j, w) in edges:
            # Normalised weight (0-1).
            norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
            
            # Rainbow colour: weak (blue) -> medium (green/yellow) -> strong (red).
            color = get_rainbow_color(norm_w)
            
            # Line width grows linearly with weight (0.5 to 4).
            line_width = 0.5 + 3.5 * norm_w
            
            fig.add_trace(
                go.Scatter(
                    x=[xs[i], xs[j]],
                    y=[ys[i], ys[j]],
                    mode="lines",
                    hoverinfo="text",
                    hovertext=f"ch{i} - ch{j}<br>weight: {w:.4f}",
                    line=dict(width=line_width, color=color),
                    showlegend=False,
                )
            )


    # Louvain clustering decides the node colours when enabled and available.
    if use_louvain and LOUVAIN_AVAILABLE:
        clusters = compute_louvain_clusters(W, thr)
        node_colors = get_cluster_colors(clusters)
        n_clusters = len(np.unique(clusters))
        title_suffix = f"  |  Louvain clusters: {n_clusters}"
    else:
        node_colors = ['#FFD700'] * C  # default gold
        clusters = np.zeros(C, dtype=int)
        title_suffix = ""

    # Draw the nodes on top of the edges.
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            mode="markers+text",
            text=[f"{k}" for k in range(C)],
            textposition="bottom center",
            textfont=dict(size=8),
            marker=dict(
                size=14, 
                color=node_colors,
                line=dict(width=2, color='white')
            ),
            hoverinfo="text",
            hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(C)],
            showlegend=False,
        )
    )

    fig.update_layout(
        title=f"Estimated Network (thr={thr:.3f})  edges={len(edges)}{title_suffix}",
        height=600,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        margin=dict(l=10, r=10, t=50, b=50),
        paper_bgcolor='rgba(0,0,0,0.9)',
        plot_bgcolor='rgba(0,0,0,0.9)',
    )
    # Lock the aspect ratio so the layout (and arrow angles) are not distorted.
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    
    # Text legend standing in for a colour bar.
    if edges:
        fig.add_annotation(
            text=f"Edge color/width: weak (blue/thin) → medium (green/yellow) → strong (red/thick)<br>Weight range: {min_w:.3f} - {max_w:.3f}",
            xref="paper", yref="paper",
            x=0.5, y=-0.05,
            showarrow=False,
            font=dict(size=10, color='white'),
            xanchor='center',
        )
    
    return fig, len(edges)


def make_edgecount_curve(W: np.ndarray) -> go.Figure:
    """Plot the number of surviving edges as the threshold sweeps the weights.

    The strict upper triangle of ``W`` defines the sweep: 120 evenly spaced
    thresholds running from the largest weight down to the smallest, or a
    single threshold of 0.0 when there are no off-diagonal entries. Each
    threshold is passed to ``threshold_edges`` and the length of its result
    becomes the y value.

    Args:
        W: Square weight matrix.

    Returns:
        A Plotly figure with one line trace (threshold vs. edge count).
    """
    row_idx, col_idx = np.triu_indices(W.shape[0], k=1)
    upper_weights = W[row_idx, col_idx]

    if upper_weights.size:
        strongest = float(upper_weights.max())
        weakest = float(upper_weights.min())
        thr_grid = np.linspace(strongest, weakest, 120)
    else:
        thr_grid = np.array([0.0])

    counts = []
    for level in thr_grid:
        counts.append(len(threshold_edges(W, float(level))))

    curve = go.Scatter(x=thr_grid, y=counts, mode="lines")
    fig = go.Figure(data=[curve])
    fig.update_layout(
        title="Edge count vs threshold (lower thr => more edges)",
        xaxis_title="threshold",
        yaxis_title="edge count",
        height=300,
    )
    return fig


def to_csv_bytes_matrix(mat: np.ndarray, fmt: str) -> bytes:
    """Serialize an array as comma-separated text and return it as UTF-8 bytes.

    Args:
        mat: Array to write; formatting is delegated to ``np.savetxt``.
        fmt: printf-style format applied to every value (e.g. ``"%.6f"``).

    Returns:
        The CSV text encoded as UTF-8.
    """
    with io.StringIO() as text_buffer:
        np.savetxt(text_buffer, mat, delimiter=",", fmt=fmt)
        csv_text = text_buffer.getvalue()
    return csv_text.encode("utf-8")


def to_csv_bytes_edges(edges: List[Tuple[int, int, float]]) -> bytes:
    """Serialize an edge list as CSV and return it as UTF-8 bytes.

    The output starts with the header ``source,target,weight`` followed by one
    line per edge; weights are written with six decimal places.

    Args:
        edges: Iterable of ``(source, target, weight)`` triples.

    Returns:
        The CSV text encoded as UTF-8 (header only when ``edges`` is empty).
    """
    header = "source,target,weight\n"
    rows = [f"{src},{dst},{weight:.6f}\n" for src, dst, weight in edges]
    return "".join([header, *rows]).encode("utf-8")


# ============================================================
# Sidebar UI
# ============================================================
# Input format selector: decides which upload/preprocess branch runs below
# (the EEGLAB branch is chosen when the value starts with "EEGLAB").
st.sidebar.header("Input format")
input_mode = st.sidebar.radio("データ形式", ["EEGLAB (.set + .fdt)", "MATLAB (.mat)"], index=0)

# Band edges (Hz) forwarded as f_low / f_high to the first preprocessing pass,
# whose result is stored in the session as "prep".
st.sidebar.header("Preprocess (auto)")
f_low_src = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=4.0, step=1.0, key="low_src")
f_high_src = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=8.0, step=1.0, key="high_src")

# Band edges (Hz) for the second preprocessing pass ("prep_tgt"), which the
# PAC estimators read their amplitude from. The labels repeat the ones above,
# so each widget gets its own explicit key.
st.sidebar.header("if you use CFC+PAC:")
f_low_tgt = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=25.0, step=1.0, key="low_tgt")
f_high_tgt = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=40.0, step=1.0, key="high_tgt")

# Options consumed by the time-series viewer further down.
st.sidebar.header("Viewer controls")
win_sec = st.sidebar.number_input("Window length (sec)", min_value=0.1, value=5.0, step=0.1)
decim = st.sidebar.selectbox("Decimation (間引き)", options=[1, 2, 5, 10, 20, 50], index=1)
offset_mode = st.sidebar.checkbox("重ね描画のオフセット表示", value=True)
show_rangeslider = st.sidebar.checkbox("Plotly rangesliderを表示", value=False)
# The chosen value is used verbatim as a key into the preprocessed-data dict,
# so the option strings must match that dict's keys.
signal_view = st.sidebar.radio(
    "表示する信号", 
    ["raw", "filtered", "amplitude", "phase"],
    index=1,
    help="raw: 生信号, filtered: バンドパス後, amplitude: Hilbert振幅(envelope), phase: Hilbert位相"
)

st.title("EEG timeseries viewer + network estimation")


# ============================================================
# Load + preprocess (EEGLAB / MAT)
# ============================================================
# Streamlit re-executes this script on every widget interaction while the
# uploaded file stays in the uploader. The results are therefore tagged with
# a signature of their inputs; preprocessing is redone and the cached network
# "W" is invalidated only when that signature changes. Resetting "W" on every
# rerun would force a full re-estimation after each slider move.


def _upload_identity(uploaded) -> tuple:
    """Return a cheap fingerprint of one uploaded file (no byte hashing)."""
    return (
        getattr(uploaded, "file_id", None),
        uploaded.name,
        getattr(uploaded, "size", None),
    )


def _loaded_data_is_current(signature: tuple) -> bool:
    """Tell whether the session already holds results for ``signature``."""
    return (
        st.session_state.get("input_signature") == signature
        and "prep" in st.session_state
        and "prep_tgt" in st.session_state
    )


def _store_loaded_data(prep_src, prep_tgt, signature: tuple) -> None:
    """Publish fresh preprocessing results and invalidate the old network."""
    st.session_state["prep"] = prep_src
    st.session_state["prep_tgt"] = prep_tgt
    st.session_state["W"] = None
    st.session_state["input_signature"] = signature


def _forget_loaded_data() -> None:
    """Drop every session entry derived from the previously loaded data."""
    for stale_key in ("prep", "prep_tgt", "input_signature"):
        st.session_state.pop(stale_key, None)
    st.session_state["W"] = None


# Both preprocessing passes depend on these four band edges.
band_signature = (
    float(f_low_src),
    float(f_high_src),
    float(f_low_tgt),
    float(f_high_tgt),
)

if input_mode.startswith("EEGLAB"):
    st.sidebar.header("Upload (.set + .fdt)")
    uploaded_files = st.sidebar.file_uploader(
        "Upload EEGLAB files",
        type=["set", "fdt"],
        accept_multiple_files=True,
    )

    if uploaded_files:
        set_file, fdt_file = pick_set_fdt(uploaded_files)
        if set_file is None or fdt_file is None:
            st.warning("`.set` と `.fdt` の両方をアップロードしてください。")
        else:
            signature = (
                "eeglab",
                _upload_identity(set_file),
                _upload_identity(fdt_file),
                band_signature,
            )
            try:
                if not _loaded_data_is_current(signature):
                    with st.spinner("Loading EEGLAB + preprocessing (bandpass + hilbert)..."):
                        set_bytes = set_file.getvalue()
                        fdt_bytes = fdt_file.getvalue()
                        # Same recording, two bands: "src" feeds the viewer and
                        # the phase side, "tgt" the amplitude side of PAC.
                        prep_src = preprocess_all_eeglab(
                            set_bytes=set_bytes,
                            fdt_bytes=fdt_bytes,
                            set_name=set_file.name,
                            fdt_name=fdt_file.name,
                            f_low=float(f_low_src),
                            f_high=float(f_high_src),
                        )
                        prep_tgt = preprocess_all_eeglab(
                            set_bytes=set_bytes,
                            fdt_bytes=fdt_bytes,
                            set_name=set_file.name,
                            fdt_name=fdt_file.name,
                            f_low=float(f_low_tgt),
                            f_high=float(f_high_tgt),
                        )
                    _store_loaded_data(prep_src, prep_tgt, signature)
                loaded = st.session_state["prep"]
                st.success(f"Loaded & preprocessed. (T,C)={loaded['raw'].shape}  fs={loaded['fs']:.2f}Hz")
            except Exception as e:
                # UI boundary: report the failure instead of crashing the app.
                _forget_loaded_data()
                st.error(f"読み込み/前処理エラー: {e}")

else:
    st.sidebar.header("Upload (.mat)")
    mat_file = st.sidebar.file_uploader("Upload .mat", type=["mat"])

    if mat_file is not None:
        mat_bytes = mat_file.getvalue()
        try:
            cands = load_mat_candidates_cached(mat_bytes)
            if not cands:
                st.error("数値の1D/2D配列が見つかりませんでした。")
                st.info("MATファイルの構造を確認しています...")

                # Debug aid: list what the MAT file actually contains.
                try:
                    from scipy.io import loadmat
                    mat_data = loadmat(io.BytesIO(mat_bytes))
                    st.write("**MATファイルに含まれる変数:**")
                    for k, v in mat_data.items():
                        if not k.startswith('__'):
                            if isinstance(v, np.ndarray):
                                st.write(f"- `{k}`: shape={v.shape}, dtype={v.dtype}, ndim={v.ndim}")
                            else:
                                st.write(f"- `{k}`: type={type(v).__name__}")
                except Exception as e:
                    st.write(f"デバッグ情報の取得に失敗: {e}")

                    # scipy could not read the file; try it as HDF5 instead.
                    import os
                    import tempfile

                    tmp_path = None
                    try:
                        import h5py

                        with tempfile.NamedTemporaryFile(suffix='.mat', delete=False) as tmp:
                            # Record the path first so cleanup still happens
                            # if the write itself fails.
                            tmp_path = tmp.name
                            tmp.write(mat_bytes)

                        st.write("**HDF5形式として読み込み中...**")
                        with h5py.File(tmp_path, 'r') as f:
                            def show_structure(name, obj):
                                if isinstance(obj, h5py.Dataset):
                                    st.write(f"- `{name}`: shape={obj.shape}, dtype={obj.dtype}")
                            f.visititems(show_structure)
                    except Exception as e2:
                        st.write(f"HDF5としても読み込めませんでした: {e2}")
                    finally:
                        # The file was created with delete=False, so remove it
                        # here whether or not h5py succeeded.
                        if tmp_path is not None:
                            try:
                                os.unlink(tmp_path)
                            except OSError:
                                # Best-effort cleanup of a debug temp file.
                                pass
            else:
                key = st.sidebar.selectbox("EEG配列(変数)を選択", options=list(cands.keys()))
                fs_mat = st.sidebar.number_input("Sampling rate (Hz)", min_value=0.1, value=256.0, step=0.1)

                # Preprocess automatically once a variable is selected.
                if key:
                    x = cands[key]
                    st.sidebar.write(f"選択した配列: shape={x.shape}, dtype={x.dtype}")
                    signature = (
                        "mat",
                        _upload_identity(mat_file),
                        key,
                        float(fs_mat),
                        band_signature,
                    )
                    try:
                        if not _loaded_data_is_current(signature):
                            with st.spinner("Preprocessing (bandpass + hilbert)..."):
                                cfg = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_src), f_high=float(f_high_src))
                                prep_src = preprocess_tc(x, cfg)
                                cfg_tgt = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_tgt), f_high=float(f_high_tgt))
                                prep_tgt = preprocess_tc(x, cfg_tgt)
                            _store_loaded_data(prep_src, prep_tgt, signature)
                        loaded = st.session_state["prep"]
                        st.success(f"Loaded MAT '{key}'. (T,C)={loaded['raw'].shape}  fs={loaded['fs']:.2f}Hz")
                    except Exception as e:
                        _forget_loaded_data()
                        st.error(f"前処理エラー: {e}")
                        import traceback
                        st.code(traceback.format_exc())
        except Exception as e:
            _forget_loaded_data()
            st.error(f".mat 読み込みエラー: {e}")
            import traceback
            st.code(traceback.format_exc())


# Nothing below can run without preprocessed data.
if "prep" not in st.session_state:
    st.info("左のサイドバーからデータをアップロードしてください。")
    st.stop()


# ============================================================
# Viewer
# ============================================================
# Pick the signal variant chosen in the sidebar; it is unpacked as a 2-D
# array whose axes are (samples, channels).
prep = st.session_state["prep"]
fs = float(prep["fs"])
X_tc = prep[signal_view]
T, C = X_tc.shape

duration_sec = (T - 1) / fs if T > 1 else 0.0
# Latest start time that still leaves a full window inside the recording.
max_start = max(0.0, float(duration_sec - win_sec))

if max_start > 0.0:
    start_sec = st.sidebar.slider(
        "Start time (sec)",
        min_value=0.0,
        max_value=float(max_start),
        value=0.0,
        step=float(max(0.01, win_sec / 200)),
    )
else:
    # The recording is no longer than one window, so there is nothing to
    # scroll. A slider with min_value == max_value is not a usable widget
    # (Streamlit rejects an empty range), so pin the start instead.
    start_sec = 0.0

st.sidebar.header("Channels")


def _set_selected_channels(channels) -> None:
    """Store a channel selection, keeping only indices valid for the data.

    The result is de-duplicated, sorted and written to both
    ``selected_channels`` and ``ch_select`` (the key of the multiselect used
    for small channel counts), so that widget reflects programmatic changes.
    Must be called before that multiselect is created in the current run.
    """
    valid = sorted({int(ch) for ch in channels if 0 <= int(ch) < C})
    st.session_state["selected_channels"] = valid
    st.session_state["ch_select"] = list(valid)


# Default selection: the first (up to) eight channels.
if "selected_channels" not in st.session_state:
    st.session_state["selected_channels"] = list(range(min(C, 8)))

# Convenience buttons for the channel selection.
col_ch1, col_ch2 = st.sidebar.columns(2)
with col_ch1:
    select_all = st.button("全選択")
with col_ch2:
    deselect_all = st.button("全解除")

# Select a contiguous range.
with st.sidebar.expander("📊 範囲で選択"):
    range_start = st.number_input("開始ch", min_value=0, max_value=C-1, value=0, step=1)
    range_end = st.number_input("終了ch", min_value=0, max_value=C-1, value=min(C-1, 7), step=1)
    if st.button("範囲を選択"):
        # Accept the bounds in either order instead of silently producing an
        # empty selection when start > end.
        first_ch, last_ch = sorted((int(range_start), int(range_end)))
        _set_selected_channels(range(first_ch, last_ch + 1))

# Preset groups of 16 channels each, clipped to the available channels.
with st.sidebar.expander("⚡ プリセット"):
    preset_col1, preset_col2 = st.columns(2)
    with preset_col1:
        if st.button("前頭部 (0-15)"):
            _set_selected_channels(range(min(16, C)))
    with preset_col2:
        if st.button("頭頂部 (16-31)"):
            _set_selected_channels(range(16, min(32, C)))
    preset_col3, preset_col4 = st.columns(2)
    with preset_col3:
        if st.button("側頭部 (32-47)"):
            _set_selected_channels(range(32, min(48, C)))
    with preset_col4:
        if st.button("後頭部 (48-63)"):
            _set_selected_channels(range(48, min(64, C)))

# Apply the select-all / clear buttons.
if select_all:
    _set_selected_channels(range(C))
if deselect_all:
    _set_selected_channels([])

# Main selection UI; a multiselect is only offered for small channel counts.
max_display = 20
if C <= max_display:
    # Drive the keyed multiselect through its session-state entry rather than
    # `default`: that makes the button handlers above take effect and keeps
    # indices left over from a larger dataset out of the widget.
    pending = st.session_state.get("ch_select", st.session_state["selected_channels"])
    st.session_state["ch_select"] = [int(ch) for ch in pending if 0 <= int(ch) < C]
    selected_channels = st.sidebar.multiselect(
        f"表示するチャンネル(全{C}ch)",
        options=list(range(C)),
        key="ch_select",
    )
    st.session_state["selected_channels"] = selected_channels
else:
    # Drop indices that do not exist in the current dataset.
    st.session_state["selected_channels"] = [
        ch for ch in st.session_state["selected_channels"] if 0 <= ch < C
    ]

    # With many channels only a summary of the selection is shown.
    st.sidebar.caption(f"選択中: {len(st.session_state['selected_channels'])} / {C} channels")

    # Add or remove a single channel.
    add_ch = st.sidebar.number_input(
        "チャンネルを追加",
        min_value=0,
        max_value=C-1,
        value=0,
        step=1,
        key="add_ch_input"
    )
    col_add, col_remove = st.sidebar.columns(2)
    with col_add:
        if st.button("➕ 追加"):
            _set_selected_channels([*st.session_state["selected_channels"], int(add_ch)])
    with col_remove:
        if st.button("➖ 削除"):
            _set_selected_channels(
                ch for ch in st.session_state["selected_channels"] if ch != int(add_ch)
            )

    # Show the current selection, abbreviated after ten entries.
    if st.session_state["selected_channels"]:
        selected_str = ", ".join(map(str, st.session_state["selected_channels"][:10]))
        if len(st.session_state["selected_channels"]) > 10:
            selected_str += f", ... (+{len(st.session_state['selected_channels']) - 10})"
        st.sidebar.text(f"選択済み: {selected_str}")

    selected_channels = st.session_state["selected_channels"]

# Two-column layout: time-series plot on the left (2/3 width), data summary
# on the right (1/3 width).
col1, col2 = st.columns([2, 1])
with col1:
    # All viewer options from the sidebar are forwarded to the figure builder.
    fig_ts = make_timeseries_figure(
        X_tc=X_tc,
        selected_channels=selected_channels,
        fs=fs,
        start_sec=float(start_sec),
        win_sec=float(win_sec),
        decim=int(decim),
        offset_mode=bool(offset_mode),
        show_rangeslider=bool(show_rangeslider),
        signal_type=signal_view,
    )
    st.plotly_chart(fig_ts)

with col2:
    st.subheader("Data info")
    # Human-readable description for each selectable signal view; the
    # "filtered" entry reports the band currently set in the sidebar.
    signal_desc = {
        "raw": "生信号(前処理なし)",
        "filtered": f"バンドパスフィルタ後 ({f_low_src}-{f_high_src} Hz)",
        "amplitude": "Hilbert振幅 (envelope)",
        "phase": "Hilbert位相 (-π ~ π)"
    }
    st.write(f"- view: **{signal_view}** ({signal_desc.get(signal_view, '')})")
    st.write(f"- fs: **{fs:.2f} Hz**")
    st.write(f"- T: {T} samples")
    st.write(f"- C: {C} channels")
    st.write(f"- duration: {duration_sec:.2f} sec")
    
    # Extra note shown only for the phase view.
    if signal_view == "phase":
        st.caption("※ 位相は -π (rad) から π (rad) の範囲で表示されます")
    
    st.caption("※ 大規模データは window + decimation 推奨。rangesliderは重い場合OFF。")

st.divider()


# ============================================================
# Estimation
# ============================================================
st.subheader("Network estimation")

# Choice of estimation method; the option values double as dispatch keys below.
estimation_method = st.radio(
    "推定手法を選択",
    options=[
        "envelope_corr",
        "phase_PLV",
        "phase_corr",
        "pac_tort",
        "pac_chatterjee"
    ],
    format_func=lambda x: {
        "envelope_corr": "Envelope Pearson correlation (振幅の相関)",
        "phase_PLV": "PLV(位相同期, PLV)",
        "phase_corr": "Circular correlation",
        "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
        "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
    }[x],
    horizontal=True,
    help="envelope_corr: 振幅包絡線のPearson相関係数 | phase_PLV: 位相のPhase Locking Value | phase_corr: 位相の相関係数 | pac_tort: Modulation index | pac_chatterjee: 位相から振幅へのChatterjee相関",
)

# Longer description of the selected method, shown as an info box.
method_info = {
    "envelope_corr": "**Envelope correlation**: 振幅包絡線(Hilbert amplitude)間のPearson相関係数を計算します。振幅が同期して変動するチャンネル間の結合を検出します。",
    "phase_PLV": "**PLV**: 位相間のPhase locking valueを計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
    "phase_corr": "**Circular correlation**: 位相間の相関係数を計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
    "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
    "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
}
st.info(method_info[estimation_method])

# Previous method and cached network from the session, if any.
last_method = st.session_state.get("last_estimation_method")
W = st.session_state.get("W")

# Estimate on the first pass (no cached W) or when the method changed.
need_estimation = (W is None) or (last_method != estimation_method)

if need_estimation:
    progress = st.progress(0.0)
    try:
        with st.spinner(f"推定中... ({estimation_method})"):
            if estimation_method == "envelope_corr":
                X_in = prep["amplitude"]
                W = estimate_network_envelope_corr(X_in)
            elif estimation_method == "phase_PLV":
                X_in = prep["phase"]
                W = estimate_network_phase_PLV(X_in, progress)
            elif estimation_method == "phase_corr":
                X_in = prep["phase"]
                W = estimate_network_phase_corr(X_in)
            elif estimation_method in ("pac_tort", "pac_chatterjee"):
                # Both PAC measures pair the phase of the first band ("prep")
                # with the amplitude of the second band ("prep_tgt").
                X_in_low_phase = prep["phase"]
                X_in_high_amplitude = st.session_state["prep_tgt"]["amplitude"]
                if estimation_method == "pac_tort":
                    pac_estimator = estimate_network_pac_tort
                else:
                    pac_estimator = estimate_network_pac_chatterjee
                W = pac_estimator(X_in_low_phase, X_in_high_amplitude, progress)
            else:
                st.error("未知の推定手法です")
                st.stop()

            # Cache the result together with the method that produced it.
            st.session_state["W"] = W
            st.session_state["last_estimation_method"] = estimation_method
            st.success(f"✅ 推定完了: {estimation_method} (ネットワークサイズ: {W.shape[0]} nodes)")
    finally:
        # Remove the bar once the work is over. It is not passed to every
        # estimator, and would otherwise stay on the page at its last value
        # (0% for those estimators) until the next rerun.
        progress.empty()
else:
    st.success(f"✓ 推定済み: **{estimation_method}** (ネットワークサイズ: {W.shape[0]} nodes)")

# W is guaranteed to exist at this point.
# Threshold slider and network figures.
# Take the slider range from the finite weights only, so one NaN/inf entry
# does not throw away the real maximum.
weights_arr = np.asarray(W)
finite_weights = weights_arr[np.isfinite(weights_arr)]
wmax = float(finite_weights.max()) if finite_weights.size else 1.0
# Keep the range strictly positive; this also keeps the default value inside
# [min_value, max_value] when every weight is zero or negative.
thr_max = max(0.0001, wmax)

col_thr1, col_thr2 = st.columns([3, 1])
with col_thr1:
    thr = st.slider(
        "閾値 (threshold)  ※下げるほどエッジが増えます",
        min_value=0.0,
        max_value=thr_max,
        value=thr_max / 2,
        # Scale the step with the range: a fixed 0.001 floor is coarser than
        # the whole range for measures whose weights are very small.
        step=thr_max / 200,
        format="%.4g",
    )
with col_thr2:
    use_louvain = st.checkbox(
        "Louvainクラスタ",
        value=True,
        disabled=not LOUVAIN_AVAILABLE,
        help="ノードの色をコミュニティ検出結果で塗り分けます"
    )

# Electrode positions, if the loader provided them.
electrode_pos = prep.get("electrode_pos", None)
# Rotate the 2-D coordinates by 90 degrees counter-clockwise: (x, y) -> (-y, x).
if electrode_pos is not None:
    electrode_pos = np.asarray(electrode_pos, dtype=np.float32)
    if electrode_pos.ndim == 2 and electrode_pos.shape[1] >= 2:
        pos2 = electrode_pos[:, :2]
        electrode_pos = np.column_stack([-pos2[:, 1], pos2[:, 0]])
electrode_pos_3d = prep.get("electrode_pos_3d", None)
if electrode_pos_3d is not None:
    # Convert up front so `.shape` is available to the messages below even
    # when the stored value is not already an ndarray.
    electrode_pos_3d = np.asarray(electrode_pos_3d, dtype=np.float32)

if electrode_pos is not None:
    st.info(f"✓ 電極位置を使用してネットワークを配置 ({electrode_pos.shape[0]} channels)")
else:
    st.info("ℹ️ 電極位置が取得できなかったため、円形配置を使用します")

# Report whether 3-D coordinates are available.
if electrode_pos_3d is not None:
    st.success(f"✓ 3D電極座標を取得しました ({electrode_pos_3d.shape[0]} channels) - 下部に3Dビューアを表示します")
else:
    st.info("ℹ️ 3D電極座標が取得できませんでした - 2D表示のみです")

net_col1, net_col2 = st.columns([2, 1])
with net_col1:
    fig_net, edge_n = make_network_figure(
        W,
        float(thr),
        use_louvain=use_louvain,
        electrode_pos=electrode_pos,
    )
    st.plotly_chart(fig_net)
    # 3-D network view, only when 3-D coordinates exist.
    if electrode_pos_3d is not None:
        # The coordinates must be one (x, y, z) row per network node.
        if electrode_pos_3d.ndim == 2 and electrode_pos_3d.shape[0] == W.shape[0] and electrode_pos_3d.shape[1] == 3:
            st.subheader("3D Viewer")
            fig_3d = make_network_figure_3d(
                W=W,
                thr=float(thr),
                electrode_pos_3d=electrode_pos_3d,
                use_louvain=use_louvain,
            )
            st.plotly_chart(
                fig_3d,
                width="stretch",
                config={"displayModeBar": True, "scrollZoom": True},
            )
        else:
            st.warning(f"3D座標のshapeが不正です: {electrode_pos_3d.shape}(期待: (C,3), C={W.shape[0]})")

with net_col2:
    st.metric("Edges", edge_n)
    st.plotly_chart(make_edgecount_curve(W))


# Placeholder section.
st.write("# Hypothesis testing")
st.write("Coming soon ...")