# FC_Xtra / app.py
# Last commit by JayLacoma: "Update app.py" (bc3b19f, verified)
import os
import shutil
import traceback
import zipfile
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # select the non-GUI Agg backend before pyplot is imported
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
# Import your full pipeline
from roi_connectivity import create_rest_epochs, compute_roi_connectivity_matrix
# Root folder for everything this app writes: the per-run working folder,
# the result ZIPs and the error ZIPs. The path is relative, so it resolves
# against the process's current working directory.
OUTPUT_DIR: Path = Path("connectivity_output")
OUTPUT_DIR.mkdir(parents=False, exist_ok=True)
def _upload_path(npy_file):
    """Return the local filesystem path of a Gradio file upload.

    Newer Gradio (``gr.File`` with its default ``type="filepath"``) passes a
    plain path string; older Gradio passes a tempfile-like object whose
    ``.name`` is the path. Both are accepted.
    """
    if npy_file is None:
        raise ValueError("No .npy file was uploaded.")
    # Check for str/PathLike first: a pathlib.Path also has `.name`, but that
    # is only the base name, not the full path.
    if isinstance(npy_file, (str, os.PathLike)):
        return npy_file
    return npy_file.name


def _require_positive(label, value):
    """Raise ValueError unless *value* is a number strictly greater than zero."""
    # `not value > 0` (rather than `value <= 0`) also rejects NaN.
    if value is None or not value > 0:
        raise ValueError(f"{label} must be a positive number, got {value!r}.")


def _save_heatmap(conn_matrix, band, path):
    """Render *conn_matrix* as a seaborn heatmap and save it as an image at *path*."""
    fig, ax = plt.subplots(figsize=(20, 15))
    try:
        sns.heatmap(conn_matrix, cmap="viridis", square=True,
                    cbar_kws={'label': 'WPLI'}, ax=ax)
        ax.set_title(f"Source-Level Connectivity: {band} Band")
        fig.tight_layout()
        fig.savefig(path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure: pyplot keeps open figures alive, so a
        # failure mid-render would otherwise leak one per request.
        plt.close(fig)


def _bundle_zip(zip_path, files):
    """Write *files* (Paths) into a new ZIP at *zip_path*, stored by base name.

    Returns the ZIP path as a ``str``, the form handed back to Gradio.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for file_path in files:
            zf.write(file_path, file_path.name)
    return str(zip_path)


def run_connectivity(
    npy_file,
    band: str = "Low_Beta",
    epoch_duration: float = 2.5,
    sfreq: float = 500.0,
):
    """Compute an ROI connectivity matrix from an uploaded DiFuMo ``.npy`` file.

    The upload is copied into a freshly emptied ``OUTPUT_DIR / "connectivity_run"``
    folder and split into epochs with ``create_rest_epochs``. The epochs go to
    ``compute_roi_connectivity_matrix`` with ``method='wpli2_debiased'``. The
    resulting matrix is saved as a CSV plus a heatmap PNG. Both files are
    bundled into ``OUTPUT_DIR / f"connectivity_{band}_output.zip"``.

    Args:
        npy_file: The Gradio upload. It is either a path string or a
            tempfile-like object exposing ``.name``; ``None`` means nothing
            was uploaded.
        band: Frequency-band name. It is forwarded to
            ``compute_roi_connectivity_matrix`` and used in output file names.
        epoch_duration: Epoch length in seconds; must be positive.
        sfreq: Sampling frequency in Hz; must be positive.

    Returns:
        The path (``str``) of a ZIP archive. On success it contains the CSV
        and the heatmap. If any step inside the pipeline raises, the function
        instead returns ``OUTPUT_DIR / "connectivity_error.zip"``. That ZIP
        holds ``connectivity_error.log`` with the full traceback, so the UI
        always receives a downloadable file.
    """
    subject_output = OUTPUT_DIR / "connectivity_run"
    # Start every run from an empty folder so files from a previous run can
    # never leak into this run's results.
    if subject_output.exists():
        shutil.rmtree(subject_output)
    subject_output.mkdir(parents=True)

    try:
        _require_positive("Epoch duration", epoch_duration)
        _require_positive("Sampling frequency", sfreq)

        # Copy the upload into the run folder under a fixed name.
        data_path = subject_output / "difumo_time_courses.npy"
        shutil.copy(_upload_path(npy_file), data_path)

        # Create epochs (resting-state only)
        epochs = create_rest_epochs(
            data_file=str(data_path),
            duration=epoch_duration,
            sfreq=sfreq,
        )
        print(f"✅ Created {len(epochs)} epochs")

        conn_matrix = compute_roi_connectivity_matrix(
            epochs=epochs,
            band_name=band,
            method='wpli2_debiased',
            sfreq=sfreq,
        )
        # `to_csv` needs a DataFrame. pd.DataFrame() keeps the index/column
        # labels of an existing DataFrame and wraps a 2-D array otherwise.
        conn_matrix = pd.DataFrame(conn_matrix)
        print(f"🌐 Connectivity matrix shape: {conn_matrix.shape}")

        csv_path = subject_output / f"connectivity_{band}.csv"
        conn_matrix.to_csv(csv_path)

        heatmap_path = subject_output / f"connectivity_{band}_heatmap.png"
        _save_heatmap(conn_matrix, band, heatmap_path)

        return _bundle_zip(
            OUTPUT_DIR / f"connectivity_{band}_output.zip",
            [csv_path, heatmap_path],
        )
    except Exception:
        # UI boundary: report the failure as a downloadable log instead of
        # raising. Keep the full traceback (type, message and call stack),
        # not just str(e).
        details = traceback.format_exc()
        traceback.print_exc()  # also surface it in the server logs
        error_log = subject_output / "connectivity_error.log"
        error_log.write_text(f"Connectivity failed:\n{details}", encoding="utf-8")
        return _bundle_zip(OUTPUT_DIR / "connectivity_error.zip", [error_log])
# Gradio UI: the left column collects the upload and analysis parameters;
# the right column returns the ZIP produced by run_connectivity.
with gr.Blocks(theme=gr.themes.Base(), title="Functional Connectivity") as demo:
    gr.Markdown("# Functional Connectivity from DiFuMo Time Series")
    gr.Markdown("Upload `difumo_time_courses.npy` (from LCMV pipeline). Computes ROI connectivity matrix.")
    with gr.Row():
        with gr.Column():
            npy_input = gr.File(label="DiFuMo Time Courses (.npy)", file_types=[".npy"])
            band = gr.Dropdown(
                choices=["Theta", "Alpha", "Low_Beta", "High_Beta", "Low_Gamma", "High_Gamma"],
                value="Low_Beta",
                label="Frequency Band",
            )
            epoch_duration = gr.Number(2.5, label="Epoch Duration (seconds)")
            sfreq = gr.Number(500.0, label="Sampling Frequency (Hz)")
            run_btn = gr.Button("Compute Connectivity", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Connectivity Results (.zip)")
    # Inputs are passed positionally, so their order must match
    # run_connectivity(npy_file, band, epoch_duration, sfreq).
    run_btn.click(
        fn=run_connectivity,
        inputs=[npy_input, band, epoch_duration, sfreq],
        outputs=output_file,
    )

# Launch only when run as a script (`python app.py`). Importing the module,
# for example from tests or Gradio's reload mode, which looks up `demo`,
# must not start a server.
if __name__ == "__main__":
    demo.launch()