File size: 3,645 Bytes
503e48c
bc3b19f
 
 
503e48c
 
bc3b19f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503e48c
bc3b19f
 
 
 
 
503e48c
bc3b19f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cb4139
bc3b19f
 
 
1cb4139
bc3b19f
 
 
 
 
 
 
 
503e48c
bc3b19f
 
 
 
 
1cb4139
bc3b19f
 
 
 
 
 
 
 
 
 
 
 
 
6f76a60
bc3b19f
503e48c
 
 
bc3b19f
 
 
503e48c
 
 
bc3b19f
 
 
503e48c
 
bc3b19f
503e48c
 
bc3b19f
 
 
503e48c
 
bc3b19f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import shutil
import traceback
import zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr

# Import your full pipeline
from roi_connectivity import create_rest_epochs, compute_roi_connectivity_matrix

# Root folder for every artifact this app writes: the per-run scratch
# directory plus the result/error ZIP archives. Created once at import time,
# relative to the current working directory.
OUTPUT_DIR = Path("connectivity_output")
os.makedirs(OUTPUT_DIR, exist_ok=True)

def _upload_path(npy_file) -> str:
    """Return the filesystem path of a Gradio ``gr.File`` upload.

    Older Gradio releases pass the callback a temp-file wrapper exposing
    ``.name``; newer ones (``type="filepath"``, the default) pass the path
    string itself. Accept either so the handler works across versions.
    """
    return os.fspath(getattr(npy_file, "name", npy_file))


def _save_heatmap(conn_matrix, band: str, out_path: Path) -> None:
    """Render *conn_matrix* as a WPLI heatmap PNG at *out_path*.

    The figure is closed in ``finally`` so a plotting or saving failure does
    not leave an open matplotlib figure behind in the long-running app.
    """
    fig, ax = plt.subplots(figsize=(20, 15))
    try:
        sns.heatmap(conn_matrix, cmap="viridis", square=True,
                    cbar_kws={'label': 'WPLI'}, ax=ax)
        ax.set_title(f"Source-Level Connectivity: {band} Band")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)


def _write_error_zip(run_dir: Path, message: str) -> str:
    """Write *message* to a log in *run_dir*, zip it, and return the ZIP path."""
    error_log = run_dir / "connectivity_error.log"
    # Explicit UTF-8: exception text may contain non-ASCII characters, and the
    # platform default encoding is not guaranteed to handle them.
    error_log.write_text(message, encoding="utf-8")
    zip_path = OUTPUT_DIR / "connectivity_error.zip"
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.write(error_log, error_log.name)
    return str(zip_path)


def run_connectivity(
    npy_file,
    band: str = "Low_Beta",
    epoch_duration: float = 2.5,
    sfreq: float = 500.0
):
    """Compute a ROI connectivity matrix from an uploaded DiFuMo ``.npy`` file.

    Gradio click handler. Copies the upload into a fresh scratch directory
    ``OUTPUT_DIR/connectivity_run`` (wiped on every call). It builds
    resting-state epochs with ``create_rest_epochs`` and computes a
    ``'wpli2_debiased'`` connectivity matrix with
    ``compute_roi_connectivity_matrix``. It then saves the matrix as CSV
    plus a heatmap PNG and bundles both into a ZIP.

    Args:
        npy_file: The ``gr.File`` upload. This is either a temp-file object
            with ``.name`` or a plain path string, or ``None`` if nothing was
            uploaded.
        band: Frequency-band name forwarded to
            ``compute_roi_connectivity_matrix`` and embedded in output file
            names (assumed to be a safe filename component, as the UI
            dropdown provides).
        epoch_duration: Epoch length in seconds, forwarded to
            ``create_rest_epochs``.
        sfreq: Sampling frequency in Hz, forwarded to both pipeline calls.

    Returns:
        On success, the path (``str``) of ``OUTPUT_DIR/connectivity_<band>_output.zip``.
        That ZIP holds ``connectivity_<band>.csv`` and
        ``connectivity_<band>_heatmap.png``.

        On any failure inside the pipeline, the path of
        ``OUTPUT_DIR/connectivity_error.zip`` instead. That ZIP holds
        ``connectivity_error.log`` with the error and traceback. Errors are
        reported this way rather than raised, so the UI always gets a file.

    NOTE(review): the scratch directory and ZIP names are fixed, so two
    overlapping runs would clobber each other's files.
    """
    run_dir = OUTPUT_DIR / "connectivity_run"
    if run_dir.exists():
        shutil.rmtree(run_dir)
    # parents=True: also recreate OUTPUT_DIR if it was removed after startup.
    run_dir.mkdir(parents=True)

    try:
        if npy_file is None:
            raise ValueError("No .npy file was uploaded.")

        # Copy the upload into the run directory under a fixed name.
        data_path = run_dir / "difumo_time_courses.npy"
        shutil.copy(_upload_path(npy_file), data_path)

        # Create epochs (resting-state only)
        epochs = create_rest_epochs(
            data_file=str(data_path),
            duration=epoch_duration,
            sfreq=sfreq
        )
        print(f"✅ Created {len(epochs)} epochs")

        # Compute connectivity matrix
        conn_matrix = compute_roi_connectivity_matrix(
            epochs=epochs,
            band_name=band,
            method='wpli2_debiased',
            sfreq=sfreq
        )
        print(f"🌐 Connectivity matrix shape: {conn_matrix.shape}")

        # presumably a pandas DataFrame (it must provide .to_csv) — confirm
        # against roi_connectivity.
        csv_path = run_dir / f"connectivity_{band}.csv"
        conn_matrix.to_csv(csv_path)

        heatmap_path = run_dir / f"connectivity_{band}_heatmap.png"
        _save_heatmap(conn_matrix, band, heatmap_path)

        # Bundle into ZIP
        zip_path = OUTPUT_DIR / f"connectivity_{band}_output.zip"
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.write(csv_path, csv_path.name)
            zf.write(heatmap_path, heatmap_path.name)

        return str(zip_path)

    except Exception as e:
        # Top-level UI boundary: hand the user a downloadable error report
        # including the full traceback, rather than only str(e).
        print(f"❌ Connectivity failed: {e}")
        return _write_error_zip(
            run_dir,
            f"Connectivity failed:\n{e}\n\n{traceback.format_exc()}",
        )

# Gradio UI: a single-page form that uploads the DiFuMo time courses, picks
# the band / epoch length / sampling rate, and calls run_connectivity, which
# returns the path of a downloadable ZIP.
with gr.Blocks(theme=gr.themes.Base(), title="Functional Connectivity") as demo:
    gr.Markdown("# Functional Connectivity from DiFuMo Time Series")
    gr.Markdown("Upload `difumo_time_courses.npy` (from LCMV pipeline). Computes ROI connectivity matrix.")

    with gr.Row():
        # Left column: inputs and the run button.
        with gr.Column():
            npy_input = gr.File(label="DiFuMo Time Courses (.npy)", file_types=[".npy"])
            band = gr.Dropdown(
                choices=["Theta", "Alpha", "Low_Beta", "High_Beta", "Low_Gamma", "High_Gamma"],
                value="Low_Beta",
                label="Frequency Band"
            )
            epoch_duration = gr.Number(2.5, label="Epoch Duration (seconds)")
            sfreq = gr.Number(500.0, label="Sampling Frequency (Hz)")
            run_btn = gr.Button("Compute Connectivity", variant="primary")

        # Right column: the result (or error-report) ZIP.
        with gr.Column():
            output_file = gr.File(label="Download Connectivity Results (.zip)")

    # Inputs are passed positionally, in run_connectivity's parameter order.
    run_btn.click(
        fn=run_connectivity,
        inputs=[npy_input, band, epoch_duration, sfreq],
        outputs=output_file,
    )

# Launch only when executed as a script (e.g. ``python app.py``), so that
# importing this module, for instance to reuse run_connectivity, does not
# start a blocking web server.
if __name__ == "__main__":
    demo.launch()