# ICA_Xtra / app.py
# Author: JayLacoma
# Commit: 4497fc2 "Update app.py" (verified)
import os
import shutil
import traceback
import zipfile
from pathlib import Path

import gradio as gr
import mne

# Import your pipeline
from ica_xtra import run_preprocessing_pipeline
# Uploads are always processed under this one fixed subject ID; it is never
# taken from user input.
SUBJECT_ID = "sub"

# Built-in montage (.gpsc) file passed to the pipeline on every run. The path
# is relative, so it resolves against the current working directory: the app
# must be launched from the repository root, where this file has to live.
_MONTAGE_FILENAME = "ghw280_from_egig.gpsc"
BUILT_IN_GPSC = Path(_MONTAGE_FILENAME)
if not BUILT_IN_GPSC.is_file():
    raise FileNotFoundError(f"Required montage file not found: {BUILT_IN_GPSC}")

# Root for all run artefacts (the per-subject folder and the result ZIPs).
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True)
def _resolve_upload(eeg_file) -> Path:
    """Return the on-disk path of a ``gr.File`` upload.

    Gradio may hand the callback either a filepath string or a tempfile-like
    object exposing ``.name``; both are accepted.

    Raises:
        ValueError: if nothing was uploaded (``eeg_file`` is None).
    """
    if eeg_file is None:
        raise ValueError("No EEG file uploaded.")
    if isinstance(eeg_file, (str, os.PathLike)):
        return Path(eeg_file)
    return Path(eeg_file.name)


def _extract_mff_zip(zip_file: Path, eeg_path: Path) -> None:
    """Unpack a zipped MFF recording so that ``eeg_path`` is the .mff folder.

    Two archive layouts are accepted: the MFF contents (``info.xml`` ...) at
    the archive root, or the ``.mff`` folder itself zipped
    (``<name>.mff/info.xml``). The shallowest ``info.xml`` found marks the
    MFF root. Extraction goes to a sibling staging folder, which is removed
    afterwards so it does not end up in the output ZIP.

    Raises:
        ValueError: if the upload is not a ZIP or contains no ``info.xml``.
    """
    if not zipfile.is_zipfile(zip_file):
        raise ValueError("MFF input must be a ZIP file containing the .mff folder.")
    staging = eeg_path.with_name(eeg_path.name + ".extract")
    # NOTE: zipfile.extractall strips absolute paths and ".." components from
    # member names, so members cannot escape `staging`; archive size is not
    # limited here.
    with zipfile.ZipFile(zip_file, "r") as zf:
        zf.extractall(staging)
    candidates = sorted(
        staging.rglob("info.xml"),
        key=lambda p: (len(p.parts), str(p)),  # shallowest first, deterministic
    )
    if not candidates:
        shutil.rmtree(staging)
        raise ValueError("Invalid MFF: missing 'info.xml' in the ZIP.")
    # Either `staging` itself (contents zipped) or the nested .mff folder.
    shutil.move(str(candidates[0].parent), str(eeg_path))
    if staging.exists():
        shutil.rmtree(staging)


def _select_cleaned_fif(subject_dir: Path, subject_id: str) -> Path:
    """Keep exactly one ICA-cleaned FIF in ``subject_dir`` under a standard name.

    The pipeline may leave several ``*_eeg_ica_cleaned_raw.fif`` files. The
    first (by name) containing ``post_ica`` is treated as final; otherwise the
    most recently modified one is. All others are deleted, and the survivor
    is renamed to ``<subject_id>_eeg_ica_cleaned_raw.fif``.

    Returns:
        Path of the retained, renamed file.

    Raises:
        FileNotFoundError: if no cleaned FIF exists.
    """
    cleaned_fifs = sorted(subject_dir.glob("*_eeg_ica_cleaned_raw.fif"))
    if not cleaned_fifs:
        raise FileNotFoundError("No cleaned FIF file found after preprocessing.")
    final_fif = next((f for f in cleaned_fifs if "post_ica" in f.name), None)
    if final_fif is None:
        final_fif = max(cleaned_fifs, key=os.path.getmtime)
    for f in cleaned_fifs:
        if f != final_fif:
            f.unlink()
    expected_name = subject_dir / f"{subject_id}_eeg_ica_cleaned_raw.fif"
    if final_fif != expected_name:
        # Any pre-existing file with this name was unlinked above.
        final_fif.rename(expected_name)
    return expected_name


def _zip_directory(src_dir: Path, zip_path: Path, arc_root: Path) -> None:
    """Write every file under ``src_dir`` to ``zip_path`` (deflated), with
    archive names taken relative to ``arc_root``."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(src_dir):
            for file in files:
                file_path = Path(root) / file
                zf.write(file_path, file_path.relative_to(arc_root))


def preprocess_eeg(
    eeg_file,
    input_format: str = "fif",
    apply_highpass: bool = True,
    apply_lowpass: bool = True,
    apply_notch: bool = True,
    line_freq: float = 60.0,
    pre_ica_mad: float = 3.5,
    post_ica_mad: float = 5.0,
    interpolate_before_ica: bool = False,
    use_artifact_detection: bool = True,
) -> str:
    """Run the ICA preprocessing pipeline on one upload and return a result ZIP.

    Steps:
        1. Wipe ``OUTPUT_DIR/SUBJECT_ID`` and stage the upload there.
        2. Call ``run_preprocessing_pipeline`` with the built-in montage.
        3. On success, keep a single cleaned FIF, write an info summary, and
           zip the whole subject directory (including the staged input and
           the ``plots`` folder).
        4. If step 2 or 3 raises, return instead a ZIP that holds only
           ``error.log`` (message plus traceback).

    NOTE(review): every call shares one fixed directory, so concurrent
    requests would clobber each other. This relies on Gradio running one
    request at a time; confirm the queue settings.

    Args:
        eeg_file: Value of the ``gr.File`` component (a filepath string or a
            tempfile-like object with ``.name``).
        input_format: ``"mff"`` expects a ZIP of an .mff folder. Any other
            value copies the upload as-is. The value is also forwarded to the
            pipeline.
        apply_highpass, apply_lowpass, apply_notch, line_freq,
        interpolate_before_ica: Forwarded unchanged to the pipeline.
        pre_ica_mad, post_ica_mad: Forwarded as ``pre_ica_mad_threshold`` /
            ``post_ica_mad_threshold``.
        use_artifact_detection: Forwarded as
            ``use_artifact_detection_channels``.

    Returns:
        Path (as ``str``) of the output ZIP, or of the error ZIP on failure.

    Raises:
        ValueError: No file was uploaded, or an MFF upload is not a ZIP or
            lacks ``info.xml``. These are raised before the pipeline runs and
            surface directly in the UI.
    """
    uploaded_file = _resolve_upload(eeg_file)

    # Clean output dir for this fixed subject.
    subject_dir = OUTPUT_DIR / SUBJECT_ID
    if subject_dir.exists():
        shutil.rmtree(subject_dir)
    subject_dir.mkdir(parents=True)
    (subject_dir / "plots").mkdir(exist_ok=True)

    if input_format == "mff":
        eeg_path = subject_dir / f"{SUBJECT_ID}.mff"
        _extract_mff_zip(uploaded_file, eeg_path)
    else:
        eeg_path = subject_dir / uploaded_file.name
        shutil.copy(uploaded_file, eeg_path)

    try:
        run_preprocessing_pipeline(
            subject=SUBJECT_ID,
            input_path=str(eeg_path),
            gpsc_file=str(BUILT_IN_GPSC),
            base_output_path=str(OUTPUT_DIR),
            input_format=input_format,
            apply_highpass=apply_highpass,
            apply_lowpass=apply_lowpass,
            apply_notch=apply_notch,
            line_freq=line_freq,
            pre_ica_mad_threshold=pre_ica_mad,
            post_ica_mad_threshold=post_ica_mad,
            interpolate_before_ica=interpolate_before_ica,
            use_artifact_detection_channels=use_artifact_detection,
            append_subject_to_output=True,
            plot=True,
            log_to_file=True,
            random_state=99,
        )
        cleaned_fif = _select_cleaned_fif(subject_dir, SUBJECT_ID)

        # preload=False keeps the file open; close it once the header is read.
        raw = mne.io.read_raw_fif(str(cleaned_fif), preload=False)
        try:
            info_text = str(raw.info)
        finally:
            raw.close()
        info_summary = subject_dir / f"{SUBJECT_ID}_info_summary.txt"
        with open(info_summary, "w", encoding="utf-8") as f:
            f.write(f"=== EEG Info Summary for {SUBJECT_ID} ===\n")
            f.write(info_text)

        # The ZIP lives in OUTPUT_DIR (outside subject_dir), so it never zips itself.
        zip_path = OUTPUT_DIR / f"{SUBJECT_ID}_preprocessing_output.zip"
        _zip_directory(subject_dir, zip_path, OUTPUT_DIR)
        return str(zip_path)
    except Exception as e:
        # UI boundary: report any pipeline failure as a downloadable log,
        # including the traceback so the failure can be diagnosed.
        error_log = subject_dir / "error.log"
        with open(error_log, "w", encoding="utf-8") as f:
            f.write(f"Preprocessing failed:\n{e}\n\n{traceback.format_exc()}")
        zip_path = OUTPUT_DIR / f"{SUBJECT_ID}_error.zip"
        with zipfile.ZipFile(zip_path, "w") as zf:
            zf.write(error_log, error_log.name)
        return str(zip_path)
# Gradio UI: the left column holds the upload and settings; the right column
# holds the downloadable result ZIP returned by `preprocess_eeg`.
with gr.Blocks(theme=gr.themes.Base(), title="EEG Preprocessing") as demo:
    gr.Markdown("# EEG Preprocessing Pipeline")
    gr.Markdown(
        "Upload an EEG file: a `.fif` file, or a `.mff` folder compressed into "
        "a `.zip` (select `mff` as the input format). "
        "Uses built-in `ghw280_from_egig.gpsc` montage."
    )
    with gr.Row():
        with gr.Column():
            eeg_input = gr.File(label="EEG File (.fif or zipped .mff)")
            format_dropdown = gr.Dropdown(
                choices=["fif", "mff"], value="fif", label="Input Format"
            )
            with gr.Accordion("Advanced Settings", open=False):
                hp = gr.Checkbox(True, label="Apply Highpass (1.0 Hz)")
                lp = gr.Checkbox(True, label="Apply Lowpass (100.0 Hz)")
                notch = gr.Checkbox(True, label="Apply Notch Filter")
                line_freq = gr.Number(60.0, label="Line Frequency (Hz)")
                pre_mad = gr.Slider(2.0, 8.0, value=3.5, label="Pre-ICA MAD Threshold")
                post_mad = gr.Slider(2.0, 8.0, value=5.0, label="Post-ICA MAD Threshold")
                interp_before = gr.Checkbox(False, label="Interpolate Bad Channels Before ICA")
                use_artifact = gr.Checkbox(True, label="Use Artifact Detection Channels")
            run_btn = gr.Button("Run Preprocessing", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Preprocessing Results (.zip)")

    # Input order must match the positional parameters of `preprocess_eeg`.
    run_btn.click(
        fn=preprocess_eeg,
        inputs=[
            eeg_input,
            format_dropdown,
            hp,
            lp,
            notch,
            line_freq,
            pre_mad,
            post_mad,
            interp_before,
            use_artifact,
        ],
        outputs=output_file,
    )

# Launch only when run as a script (`python app.py`), not when imported.
if __name__ == "__main__":
    demo.launch()