|
|
import os |
|
|
import shutil |
|
|
import zipfile |
|
|
from pathlib import Path |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import matplotlib |
|
|
matplotlib.use('Agg') |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
from roi_connectivity import create_rest_epochs, compute_roi_connectivity_matrix |
|
|
|
|
|
|
|
|
# Root folder for everything this app writes: the per-run working directory
# and the downloadable zip bundles. Created eagerly when the module loads.
OUTPUT_DIR = Path("connectivity_output")
OUTPUT_DIR.mkdir(parents=False, exist_ok=True)
|
|
|
|
|
def _upload_path(npy_file):
    """Return the filesystem path of a Gradio file upload.

    Gradio 3 passes a tempfile-like object exposing ``.name``; Gradio 4+
    (``gr.File`` with its default ``type="filepath"``) passes a plain ``str``
    path. Both forms, plus any ``os.PathLike``, are accepted.

    Raises:
        ValueError: if no file was uploaded (``npy_file`` is ``None``).
    """
    if npy_file is None:
        raise ValueError("No .npy file was uploaded.")
    if isinstance(npy_file, (str, os.PathLike)):
        return os.fspath(npy_file)
    return npy_file.name


def _save_heatmap(conn_matrix, band, out_path):
    """Render ``conn_matrix`` as a seaborn heatmap and save it to ``out_path``.

    The figure is always closed, even if plotting or saving fails, so that
    repeated failed runs do not accumulate open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(20, 15))
    try:
        sns.heatmap(conn_matrix, cmap="viridis", square=True,
                    cbar_kws={'label': 'WPLI'}, ax=ax)
        ax.set_title(f"Source-Level Connectivity: {band} Band")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)


def run_connectivity(
    npy_file,
    band: str = "Low_Beta",
    epoch_duration: float = 2.5,
    sfreq: float = 500.0,
):
    """Compute an ROI connectivity matrix from uploaded DiFuMo time courses.

    The uploaded ``.npy`` file is copied into a fresh run directory
    (``OUTPUT_DIR / "connectivity_run"``, wiped on every call). It is then
    split into epochs with ``create_rest_epochs`` and passed to
    ``compute_roi_connectivity_matrix`` with ``method='wpli2_debiased'``.
    The resulting matrix is written as a CSV and as a heatmap PNG, and both
    files are bundled into a zip archive.

    Args:
        npy_file: Gradio upload of the time-course array. This is either a
            path (``str`` / ``os.PathLike``, as ``gr.File`` passes in
            Gradio 4+) or a tempfile-like object exposing ``.name`` (Gradio 3).
        band: Frequency-band name forwarded to
            ``compute_roi_connectivity_matrix`` and embedded in output file
            names.
        epoch_duration: Epoch length in seconds; must be positive.
        sfreq: Sampling frequency in Hz; must be positive.

    Returns:
        Path (as ``str``) to ``OUTPUT_DIR/connectivity_<band>_output.zip`` on
        success. If any pipeline step fails, returns the path to
        ``OUTPUT_DIR/connectivity_error.zip``, which contains a log with the
        error message and full traceback. Pipeline errors are not raised, so
        the UI always receives a downloadable file. Only preparing the run
        directory can raise.
    """
    import traceback  # only needed on the error path

    subject_output = OUTPUT_DIR / "connectivity_run"
    # Start from an empty directory so files from a previous run can never
    # leak into this run's archive.
    if subject_output.exists():
        shutil.rmtree(subject_output)
    subject_output.mkdir(parents=True)

    try:
        # gr.Number yields None when the field is cleared; reject that and
        # non-positive values up front with a readable message.
        if epoch_duration is None or epoch_duration <= 0:
            raise ValueError(
                f"Epoch duration must be positive, got {epoch_duration!r}."
            )
        if sfreq is None or sfreq <= 0:
            raise ValueError(
                f"Sampling frequency must be positive, got {sfreq!r}."
            )

        data_path = subject_output / "difumo_time_courses.npy"
        shutil.copy(_upload_path(npy_file), data_path)

        epochs = create_rest_epochs(
            data_file=str(data_path),
            duration=epoch_duration,
            sfreq=sfreq,
        )
        print(f"Created {len(epochs)} epochs")

        conn_matrix = compute_roi_connectivity_matrix(
            epochs=epochs,
            band_name=band,
            method='wpli2_debiased',
            sfreq=sfreq,
        )
        print(f"Connectivity matrix shape: {conn_matrix.shape}")

        # NOTE(review): conn_matrix is presumably a pandas DataFrame (it must
        # provide .to_csv and .shape); confirm against roi_connectivity.
        csv_path = subject_output / f"connectivity_{band}.csv"
        conn_matrix.to_csv(csv_path)

        heatmap_path = subject_output / f"connectivity_{band}_heatmap.png"
        _save_heatmap(conn_matrix, band, heatmap_path)

        zip_path = OUTPUT_DIR / f"connectivity_{band}_output.zip"
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.write(csv_path, csv_path.name)
            zf.write(heatmap_path, heatmap_path.name)
        return str(zip_path)

    except Exception as e:
        # Top-level UI boundary: hand back a downloadable error report rather
        # than failing the request. Include the traceback, because str(e) alone
        # is often empty or uninformative.
        error_log = subject_output / "connectivity_error.log"
        error_log.write_text(
            f"Connectivity failed:\n{e}\n\n{traceback.format_exc()}",
            encoding="utf-8",
        )
        zip_path = OUTPUT_DIR / "connectivity_error.zip"
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.write(error_log, error_log.name)
        return str(zip_path)
|
|
|
|
|
|
|
|
# Gradio UI: upload a DiFuMo time-course array, choose a band and epoching
# parameters, and download the zipped connectivity results.
with gr.Blocks(theme=gr.themes.Base(), title="Functional Connectivity") as demo:
    gr.Markdown("# Functional Connectivity from DiFuMo Time Series")
    gr.Markdown("Upload `difumo_time_courses.npy` (from LCMV pipeline). Computes ROI connectivity matrix.")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            npy_input = gr.File(label="DiFuMo Time Courses (.npy)", file_types=[".npy"])
            band = gr.Dropdown(
                choices=["Theta", "Alpha", "Low_Beta", "High_Beta", "Low_Gamma", "High_Gamma"],
                value="Low_Beta",
                label="Frequency Band",
            )
            epoch_duration = gr.Number(2.5, label="Epoch Duration (seconds)")
            sfreq = gr.Number(500.0, label="Sampling Frequency (Hz)")
            run_btn = gr.Button("Compute Connectivity", variant="primary")

        # Right column: result download.
        with gr.Column():
            output_file = gr.File(label="Download Connectivity Results (.zip)")

    # Event listeners must be attached inside the Blocks context.
    run_btn.click(
        fn=run_connectivity,
        inputs=[npy_input, band, epoch_duration, sfreq],
        outputs=output_file,
    )

# Start the server only when run as a script (e.g. `python app.py`). Importing
# this module, for tests or Gradio's reload mode, just builds `demo`.
if __name__ == "__main__":
    demo.launch()