"""Gradio app that computes an ROI functional-connectivity matrix from DiFuMo time courses.

The user uploads ``difumo_time_courses.npy``. The app cuts it into epochs via
``create_rest_epochs`` and computes a connectivity matrix for the chosen
frequency band via ``compute_roi_connectivity_matrix``. It then returns a ZIP
containing the matrix as CSV plus a heatmap PNG. If any step fails, a ZIP
containing an error log with the full traceback is returned instead.
"""
import os
import shutil
import traceback
import zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported; the app
# only ever renders figures to files.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr

# Import your full pipeline
from roi_connectivity import create_rest_epochs, compute_roi_connectivity_matrix

# Output directory
OUTPUT_DIR = Path("connectivity_output")
OUTPUT_DIR.mkdir(exist_ok=True)


def _uploaded_path(npy_file):
    """Return the on-disk path of an uploaded file.

    Accepts a path-like / ``str`` value, or an object exposing ``.name`` (the
    form the original code relied on). Raises ``ValueError`` when nothing was
    uploaded, instead of the cryptic ``AttributeError`` on ``None.name``.
    """
    if npy_file is None:
        raise ValueError(
            "No file uploaded. Please upload difumo_time_courses.npy."
        )
    if isinstance(npy_file, (str, os.PathLike)):
        return os.fspath(npy_file)
    return npy_file.name


def _save_heatmap(conn_matrix, band, heatmap_path):
    """Render *conn_matrix* as a seaborn heatmap and save it to *heatmap_path*.

    The figure is closed in ``finally`` so a failed plot or save does not
    leave an open figure behind in the long-running server process.
    """
    fig, ax = plt.subplots(figsize=(20, 15))
    try:
        sns.heatmap(conn_matrix, cmap="viridis", square=True,
                    cbar_kws={'label': 'WPLI'}, ax=ax)
        ax.set_title(f"Source-Level Connectivity: {band} Band")
        fig.tight_layout()
        fig.savefig(heatmap_path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)


def _write_zip(zip_path, files):
    """Write *files* into *zip_path* (flat, by base name); return the ZIP path as ``str``."""
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in files:
            zf.write(path, path.name)
    return str(zip_path)


def run_connectivity(
    npy_file,
    band: str = "Low_Beta",
    epoch_duration: float = 2.5,
    sfreq: float = 500.0
):
    """Compute an ROI connectivity matrix from an uploaded DiFuMo ``.npy`` file.

    Args:
        npy_file: The uploaded file. Either a path string or an object exposing
            ``.name``. ``None`` when nothing was uploaded, which is reported
            through the error ZIP.
        band: Frequency-band name forwarded to
            ``compute_roi_connectivity_matrix``; also used in output file names.
        epoch_duration: Epoch length in seconds, forwarded to
            ``create_rest_epochs``. Must be positive.
        sfreq: Sampling frequency in Hz, forwarded to both pipeline calls.
            Must be positive.

    Returns:
        Path (``str``) to ``connectivity_<band>_output.zip``, which contains
        ``connectivity_<band>.csv`` and ``connectivity_<band>_heatmap.png``.
        If any step fails, returns the path to ``connectivity_error.zip``,
        which contains ``connectivity_error.log`` with the full traceback.
    """
    # NOTE(review): one fixed working directory is wiped and reused on every
    # call, so concurrent runs would interfere with each other.
    subject_output = OUTPUT_DIR / "connectivity_run"
    if subject_output.exists():
        shutil.rmtree(subject_output)
    subject_output.mkdir()

    try:
        # Reject cleared (None) or non-positive numeric inputs up front with a
        # readable message rather than an opaque failure deep in the pipeline.
        if epoch_duration is None or epoch_duration <= 0:
            raise ValueError(
                f"Epoch duration must be a positive number of seconds, "
                f"got {epoch_duration!r}."
            )
        if sfreq is None or sfreq <= 0:
            raise ValueError(
                f"Sampling frequency must be a positive number of Hz, "
                f"got {sfreq!r}."
            )

        # Save uploaded file
        data_path = subject_output / "difumo_time_courses.npy"
        shutil.copy(_uploaded_path(npy_file), data_path)

        # Create epochs (resting-state only)
        epochs = create_rest_epochs(
            data_file=str(data_path),
            duration=epoch_duration,
            sfreq=sfreq
        )
        print(f"✅ Created {len(epochs)} epochs")

        # Compute connectivity matrix
        conn_matrix = compute_roi_connectivity_matrix(
            epochs=epochs,
            band_name=band,
            method='wpli2_debiased',
            sfreq=sfreq
        )
        print(f"🌐 Connectivity matrix shape: {conn_matrix.shape}")

        # Save CSV
        csv_path = subject_output / f"connectivity_{band}.csv"
        conn_matrix.to_csv(csv_path)

        # Save heatmap
        heatmap_path = subject_output / f"connectivity_{band}_heatmap.png"
        _save_heatmap(conn_matrix, band, heatmap_path)

        # Bundle into ZIP
        return _write_zip(
            OUTPUT_DIR / f"connectivity_{band}_output.zip",
            [csv_path, heatmap_path],
        )

    except Exception:
        # Top-level boundary for the UI: keep the full traceback (not just the
        # message) and return it as a downloadable ZIP instead of crashing.
        details = traceback.format_exc()
        print(f"❌ Connectivity failed:\n{details}")
        error_log = subject_output / "connectivity_error.log"
        # Explicit UTF-8 so a non-ASCII message cannot break the error handler.
        error_log.write_text(f"Connectivity failed:\n{details}", encoding="utf-8")
        return _write_zip(OUTPUT_DIR / "connectivity_error.zip", [error_log])


# Gradio UI
with gr.Blocks(theme=gr.themes.Base(), title="Functional Connectivity") as demo:
    gr.Markdown("# Functional Connectivity from DiFuMo Time Series")
    gr.Markdown("Upload `difumo_time_courses.npy` (from LCMV pipeline). Computes ROI connectivity matrix.")

    with gr.Row():
        with gr.Column():
            npy_input = gr.File(label="DiFuMo Time Courses (.npy)", file_types=[".npy"])
            band = gr.Dropdown(
                choices=["Theta", "Alpha", "Low_Beta", "High_Beta", "Low_Gamma", "High_Gamma"],
                value="Low_Beta",
                label="Frequency Band"
            )
            epoch_duration = gr.Number(2.5, label="Epoch Duration (seconds)")
            sfreq = gr.Number(500.0, label="Sampling Frequency (Hz)")
            run_btn = gr.Button("Compute Connectivity", variant="primary")

        with gr.Column():
            output_file = gr.File(label="Download Connectivity Results (.zip)")

    run_btn.click(
        fn=run_connectivity,
        inputs=[npy_input, band, epoch_duration, sfreq],
        outputs=output_file,
    )

# Launch only when run as a script, so importing this module (e.g. for tests
# or reuse of run_connectivity) does not start a web server.
if __name__ == "__main__":
    demo.launch()