JayLacoma commited on
Commit
fa07286
·
verified ·
1 Parent(s): 9e257c1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import os
4
+ import tempfile
5
+ from pathlib import Path
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from lcmv_class import LCMVSourceEstimator
9
+
10
+ # Predefine global paths (fsaverage already in repo root)
11
+ FS_SRC_FILE = "fsaverage-vol-5mm-src.fif"
12
+ GPS_TEMPLATE = "ghw280_from_egig.gpsc"
13
+
14
+ def run_lcmv_app(ica_fif_file, gpsc_file, roi_indices_str="305,437,223"):
15
+ try:
16
+ # Parse ROI indices
17
+ roi_indices = [int(x.strip()) for x in roi_indices_str.split(",")]
18
+
19
+ # Create temp project dir
20
+ with tempfile.TemporaryDirectory() as tmp_dir:
21
+ tmp_path = Path(tmp_dir)
22
+
23
+ # Copy/move uploaded files
24
+ ica_path = tmp_path / "sub-01_task-ec_ica_cleaned.fif"
25
+ gps_path = tmp_path / "electrodes.gpsc"
26
+
27
+ with open(ica_path, "wb") as f:
28
+ f.write(ica_fif_file.read())
29
+ with open(gps_path, "wb") as f:
30
+ f.write(gpsc_file.read())
31
+
32
+ # Copy global fsaverage source space into temp dir (required by lcmv_class)
33
+ fs_src_dest = tmp_path / "fsaverage-vol-5mm-src.fif"
34
+ os.system(f"cp {FS_SRC_FILE} {fs_src_dest}")
35
+
36
+ # Config
37
+ config = {
38
+ 'project_base': str(tmp_path),
39
+ 'subject_id': 'sub-01',
40
+ 'task': 'ec',
41
+ 'ica_file_path': 'sub-01_task-ec_ica_cleaned.fif',
42
+ 'gpsc_file_path': 'electrodes.gpsc',
43
+ 'reg': 0.01,
44
+ 'n_jobs': 1, # Reduce parallelism for stability
45
+ }
46
+
47
+ # Run LCMV
48
+ estimator = LCMVSourceEstimator(config)
49
+ estimator.run_enhanced_computation()
50
+
51
+ # Run DiFuMo
52
+ difumo_config = {'dimension': 512, 'resolution_mm': 2}
53
+ time_courses, _ = estimator.run_difumo_extraction(difumo_config=difumo_config)
54
+
55
+ # Save time courses
56
+ tc_path = tmp_path / "difumo_time_courses.npy"
57
+ np.save(tc_path, time_courses)
58
+
59
+ # Generate PSD plot
60
+ fig = plot_psd_rois(str(tc_path), roi_indices, sfreq=500.0, cmap_name='viridis_r')
61
+ plot_path = tmp_path / "psd_plot.png"
62
+ fig.savefig(plot_path, dpi=150, bbox_inches='tight')
63
+ plt.close(fig)
64
+
65
+ return str(tc_path), str(plot_path)
66
+
67
+ except Exception as e:
68
+ return None, f"❌ Error: {str(e)}"
69
+
70
+ # Reuse your plot function (slightly modified)
71
+ def plot_psd_rois(file_path, roi_indices, sfreq=500.0, figsize=(9, 4), cmap_name='viridis_r'):
72
+ import numpy as np
73
+ import matplotlib.pyplot as plt
74
+ from scipy import signal
75
+ from nilearn import datasets
76
+
77
+ data = np.load(file_path)
78
+ if data.shape[0] != 512:
79
+ data = data.T
80
+
81
+ labels = datasets.fetch_atlas_difumo(512, 2).labels['difumo_names'].tolist()
82
+ cmap = plt.cm.get_cmap(cmap_name)
83
+ colors = cmap(np.linspace(0, 1, len(roi_indices)))
84
+
85
+ fig, ax = plt.subplots(figsize=figsize)
86
+ freqs = None
87
+
88
+ for i, idx in enumerate(roi_indices):
89
+ ts = data[idx]
90
+ freqs, psd = signal.welch(ts, fs=sfreq, nperseg=int(1.5 * sfreq))
91
+ psd_db = 10 * np.log10(psd)
92
+ name = labels[idx].strip()
93
+ ax.plot(freqs, psd_db, label=name, color=colors[i], linewidth=2)
94
+
95
+ FREQ_BANDS = [
96
+ ('Delta', (0, 4)),
97
+ ('Theta', (4, 8)),
98
+ ('Alpha', (8, 12)),
99
+ ('Low_Beta', (12, 20)),
100
+ ('High_Beta', (20, 30)),
101
+ ('Low_Gamma', (30, 50)),
102
+ ('High_Gamma', (50, 120))
103
+ ]
104
+ blue_shades = plt.cm.Blues(np.linspace(0.9, 0.2, len(FREQ_BANDS)))
105
+ band_colors = {band: blue_shades[i] for i, (band, _) in enumerate(FREQ_BANDS)}
106
+
107
+ if freqs is not None:
108
+ for band, (fmin, fmax) in FREQ_BANDS:
109
+ if fmax < freqs[0] or fmin > freqs[-1]:
110
+ continue
111
+ ax.axvspan(fmin, fmax, color=band_colors[band], alpha=0.3, zorder=0)
112
+
113
+ ax.set_xlim(0, 120)
114
+ ax.set_ylim(-70, -10)
115
+ ax.set_xlabel("Frequency (Hz)")
116
+ ax.set_ylabel("Power Spectral Density (dB/Hz)")
117
+ ax.set_title("Source-Level PSD (LCMV + DiFuMo)", fontweight='bold')
118
+ ax.grid(True, alpha=0.1)
119
+ ax.legend(loc='upper right', fontsize=9, framealpha=0.8)
120
+ plt.tight_layout()
121
+ return fig
122
+
123
+ # Gradio Interface
124
+ with gr.Blocks(title="LCMV_Xtra: EEG Source Imaging") as demo:
125
+ gr.Markdown("# 🧠 LCMV_Xtra: EEG Source Imaging with DiFuMo")
126
+ gr.Markdown("Upload ICA-cleaned EEG (.fif) and electrode (.gpsc) files to compute source-level time courses and PSD.")
127
+
128
+ with gr.Row():
129
+ with gr.Column():
130
+ fif_input = gr.File(label="ICA-cleaned EEG (.fif)", file_types=[".fif"])
131
+ gpsc_input = gr.File(label="Electrode positions (.gpsc)", file_types=[".gpsc"])
132
+ roi_input = gr.Textbox(label="DiFuMo ROI Indices (comma-separated)", value="305,437,223")
133
+ run_btn = gr.Button("Run LCMV + DiFuMo")
134
+
135
+ with gr.Column():
136
+ time_course_out = gr.File(label="Download DiFuMo Time Courses (.npy)")
137
+ plot_out = gr.Image(label="Source-Level PSD")
138
+
139
+ run_btn.click(
140
+ fn=run_lcmv_app,
141
+ inputs=[fif_input, gpsc_input, roi_input],
142
+ outputs=[time_course_out, plot_out]
143
+ )
144
+
145
+ gr.Markdown("""
146
+ > ⚠️ **Note**: Processing may take 1–3 minutes. Uses pre-downloaded `fsaverage` and 5mm source space.
147
+ """)
148
+
149
+ demo.launch()
150
+