File size: 3,645 Bytes
503e48c bc3b19f 503e48c bc3b19f 503e48c bc3b19f 503e48c bc3b19f 1cb4139 bc3b19f 1cb4139 bc3b19f 503e48c bc3b19f 1cb4139 bc3b19f 6f76a60 bc3b19f 503e48c bc3b19f 503e48c bc3b19f 503e48c bc3b19f 503e48c bc3b19f 503e48c bc3b19f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import os
import shutil
import zipfile
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
# Import your full pipeline
from roi_connectivity import create_rest_epochs, compute_roi_connectivity_matrix
# Root folder (relative to the working directory) that receives every
# generated artifact; created eagerly at import time if it is missing.
OUTPUT_DIR = Path("connectivity_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)
def run_connectivity(
    npy_file,
    band: str = "Low_Beta",
    epoch_duration: float = 2.5,
    sfreq: float = 500.0
):
    """Run the connectivity pipeline on an uploaded file and zip the results.

    The upload is copied into a fresh ``connectivity_run`` working folder
    under ``OUTPUT_DIR``, split into epochs with ``create_rest_epochs``, and
    passed to ``compute_roi_connectivity_matrix`` using the
    ``wpli2_debiased`` method. The resulting matrix is written as a CSV and
    rendered as a heatmap PNG; both are bundled into a ZIP archive.

    Args:
        npy_file: The uploaded file. Either a filesystem path (``str`` or
            ``os.PathLike``) or an object exposing the path as ``.name``.
        band: Frequency band name forwarded to the pipeline and used in the
            output file names.
        epoch_duration: Epoch length forwarded as ``duration``.
        sfreq: Sampling frequency forwarded to both pipeline calls.

    Returns:
        The path (as ``str``) of a ZIP archive. On success it holds the CSV
        and the heatmap; on any failure it holds a single
        ``connectivity_error.log`` with the message and traceback.
    """
    # Local import: only the failure path needs it.
    import traceback

    # Start every run from an empty working folder so stale artifacts from a
    # previous run can never end up in the new archive.
    subject_output = OUTPUT_DIR / "connectivity_run"
    if subject_output.exists():
        shutil.rmtree(subject_output)
    # parents=True: survive OUTPUT_DIR having been removed since import.
    subject_output.mkdir(parents=True)

    try:
        if npy_file is None:
            raise ValueError("No input file was provided.")

        # The upload may arrive as a plain path or as a file wrapper with a
        # `.name` attribute (this differs between Gradio versions/configs).
        if isinstance(npy_file, (str, os.PathLike)):
            source_path = npy_file
        else:
            source_path = npy_file.name

        # Save uploaded file
        data_path = subject_output / "difumo_time_courses.npy"
        shutil.copy(source_path, data_path)

        # Create epochs (resting-state only)
        epochs = create_rest_epochs(
            data_file=str(data_path),
            duration=epoch_duration,
            sfreq=sfreq
        )
        print(f"✅ Created {len(epochs)} epochs")

        # Compute connectivity matrix
        # NOTE(review): the result is used via `.shape` and `.to_csv`, so it
        # is presumably a pandas DataFrame — confirm in roi_connectivity.
        conn_matrix = compute_roi_connectivity_matrix(
            epochs=epochs,
            band_name=band,
            method='wpli2_debiased',
            sfreq=sfreq
        )
        print(f"🌐 Connectivity matrix shape: {conn_matrix.shape}")

        # Save CSV
        csv_path = subject_output / f"connectivity_{band}.csv"
        conn_matrix.to_csv(csv_path)

        # Save heatmap
        heatmap_path = subject_output / f"connectivity_{band}_heatmap.png"
        fig = plt.figure(figsize=(20, 15))
        try:
            sns.heatmap(conn_matrix, cmap="viridis", square=True, cbar_kws={'label': 'WPLI'})
            plt.title(f"Source-Level Connectivity: {band} Band")
            plt.tight_layout()
            plt.savefig(heatmap_path, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even when plotting or saving fails;
            # otherwise figures pile up in a long-running server process.
            plt.close(fig)

        # Bundle into ZIP
        zip_path = OUTPUT_DIR / f"connectivity_{band}_output.zip"
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.write(csv_path, csv_path.name)
            zf.write(heatmap_path, heatmap_path.name)
        return str(zip_path)

    except Exception as e:
        # Top-level boundary for the UI callback: report the failure as a
        # downloadable log instead of letting the exception propagate.
        error_log = subject_output / "connectivity_error.log"
        with open(error_log, "w", encoding="utf-8") as f:
            f.write(f"Connectivity failed:\n{str(e)}")
            # The bare message is rarely enough to locate the failing step.
            f.write("\n\n" + traceback.format_exc())
        zip_path = OUTPUT_DIR / "connectivity_error.zip"
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.write(error_log, error_log.name)
        return str(zip_path)
# Gradio UI: the left column collects the inputs and the trigger button, the
# right column exposes the ZIP archive returned by run_connectivity.
with gr.Blocks(theme=gr.themes.Base(), title="Functional Connectivity") as demo:
    gr.Markdown("# Functional Connectivity from DiFuMo Time Series")
    gr.Markdown("Upload `difumo_time_courses.npy` (from LCMV pipeline). Computes ROI connectivity matrix.")

    with gr.Row():
        with gr.Column():
            npy_input = gr.File(label="DiFuMo Time Courses (.npy)", file_types=[".npy"])
            band = gr.Dropdown(
                choices=["Theta", "Alpha", "Low_Beta", "High_Beta", "Low_Gamma", "High_Gamma"],
                value="Low_Beta",
                label="Frequency Band"
            )
            epoch_duration = gr.Number(2.5, label="Epoch Duration (seconds)")
            sfreq = gr.Number(500.0, label="Sampling Frequency (Hz)")
            run_btn = gr.Button("Compute Connectivity", variant="primary")

        with gr.Column():
            output_file = gr.File(label="Download Connectivity Results (.zip)")

    # Input order must match run_connectivity's positional parameters.
    run_btn.click(
        fn=run_connectivity,
        inputs=[npy_input, band, epoch_duration, sfreq],
        outputs=output_file,
    )

demo.launch()