JayLacoma commited on
Commit
66d9c69
·
verified ·
1 Parent(s): 60c9728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -24
app.py CHANGED
@@ -10,24 +10,22 @@ from lcmv_class import LCMVSourceEstimator
10
  BUILT_IN_GPSC = Path("ghw280_from_egig.gpsc")
11
  BUILT_IN_SRC = Path("fsaverage-vol-5mm-src.fif")
12
 
13
- # Validate files exist
14
  if not BUILT_IN_GPSC.is_file():
15
  raise FileNotFoundError(f"Required montage file not found: {BUILT_IN_GPSC}")
16
  if not BUILT_IN_SRC.is_file():
17
  raise FileNotFoundError(f"Required source space file not found: {BUILT_IN_SRC}")
18
 
19
- # Prepare fsaverage structure to avoid MNE download
20
  FS_DIR = Path("derivatives/lcmv/fsaverage")
21
  BEM_DIR = FS_DIR / "bem"
22
  BEM_DIR.mkdir(parents=True, exist_ok=True)
23
 
24
- # Copy source space to expected location
25
  EXPECTED_SRC = Path("derivatives/lcmv/fsaverage-vol-5mm-src.fif")
26
  EXPECTED_SRC.parent.mkdir(parents=True, exist_ok=True)
27
  if not EXPECTED_SRC.exists():
28
  shutil.copy(BUILT_IN_SRC, EXPECTED_SRC)
29
 
30
- # Fixed subject/task (no user input)
31
  SUBJECT_ID = "sub"
32
  TASK = ""
33
 
@@ -35,7 +33,8 @@ def run_lcmv(
35
  cleaned_fif,
36
  run_difumo: bool = True,
37
  reg: float = 0.01,
38
- n_jobs: int = 1, # HF Spaces: use 1 to avoid OOM
 
39
  ):
40
  abs_base = Path(".").resolve()
41
  subject_output = abs_base / "derivatives" / "lcmv" / f"{SUBJECT_ID}_{TASK}"
@@ -59,9 +58,7 @@ def run_lcmv(
59
  }
60
 
61
  try:
62
- # Set SUBJECTS_DIR so MNE finds fsaverage
63
  os.environ['SUBJECTS_DIR'] = str(abs_base / "derivatives" / "lcmv")
64
-
65
  estimator = LCMVSourceEstimator(config)
66
  metadata = estimator.run_enhanced_computation()
67
 
@@ -69,25 +66,29 @@ def run_lcmv(
69
  difumo_config = {'dimension': 512, 'resolution_mm': 2}
70
  estimator.run_difumo_extraction(difumo_config=difumo_config)
71
 
72
- # Bundle outputs
73
- zip_path = abs_base / f"{SUBJECT_ID}_{TASK}_lcmv_output.zip"
74
- with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
75
- for root, _, files in os.walk(subject_output):
76
- for file in files:
77
- file_path = Path(root) / file
78
- arcname = file_path.relative_to(abs_base)
79
- zf.write(file_path, arcname)
80
-
81
- return str(zip_path)
 
 
 
 
 
 
 
82
 
83
  except Exception as e:
84
  error_log = subject_output / "lcmv_error.log"
85
  with open(error_log, "w") as f:
86
  f.write(f"LCMV failed:\n{str(e)}")
87
- zip_path = abs_base / f"{SUBJECT_ID}_{TASK}_lcmv_error.zip"
88
- with zipfile.ZipFile(zip_path, "w") as zf:
89
- zf.write(error_log, error_log.name)
90
- return str(zip_path)
91
 
92
  # Gradio UI
93
  with gr.Blocks(theme=gr.themes.Base(), title="LCMV Source Estimation") as demo:
@@ -100,15 +101,21 @@ with gr.Blocks(theme=gr.themes.Base(), title="LCMV Source Estimation") as demo:
100
  run_difumo = gr.Checkbox(True, label="Extract DiFuMo 512 time courses")
101
  reg = gr.Number(0.01, label="LCMV Regularization (reg)")
102
  n_jobs = gr.Number(1, label="Parallel Jobs (n_jobs)", precision=0)
 
 
 
 
 
 
103
  run_btn = gr.Button("Run LCMV", variant="primary")
104
 
105
  with gr.Column():
106
- output_file = gr.File(label="Download LCMV Results (.zip)")
107
 
108
  run_btn.click(
109
  fn=run_lcmv,
110
- inputs=[fif_input, run_difumo, reg, n_jobs],
111
  outputs=output_file,
112
  )
113
 
114
- demo.launch()
 
10
  BUILT_IN_GPSC = Path("ghw280_from_egig.gpsc")
11
  BUILT_IN_SRC = Path("fsaverage-vol-5mm-src.fif")
12
 
 
13
  if not BUILT_IN_GPSC.is_file():
14
  raise FileNotFoundError(f"Required montage file not found: {BUILT_IN_GPSC}")
15
  if not BUILT_IN_SRC.is_file():
16
  raise FileNotFoundError(f"Required source space file not found: {BUILT_IN_SRC}")
17
 
18
+ # Prepare fsaverage structure
19
  FS_DIR = Path("derivatives/lcmv/fsaverage")
20
  BEM_DIR = FS_DIR / "bem"
21
  BEM_DIR.mkdir(parents=True, exist_ok=True)
22
 
 
23
  EXPECTED_SRC = Path("derivatives/lcmv/fsaverage-vol-5mm-src.fif")
24
  EXPECTED_SRC.parent.mkdir(parents=True, exist_ok=True)
25
  if not EXPECTED_SRC.exists():
26
  shutil.copy(BUILT_IN_SRC, EXPECTED_SRC)
27
 
28
+ # Fixed subject/task
29
  SUBJECT_ID = "sub"
30
  TASK = ""
31
 
 
33
  cleaned_fif,
34
  run_difumo: bool = True,
35
  reg: float = 0.01,
36
+ n_jobs: int = 1,
37
+ download_option: str = "full_zip", # "full_zip" or "difumo_only"
38
  ):
39
  abs_base = Path(".").resolve()
40
  subject_output = abs_base / "derivatives" / "lcmv" / f"{SUBJECT_ID}_{TASK}"
 
58
  }
59
 
60
  try:
 
61
  os.environ['SUBJECTS_DIR'] = str(abs_base / "derivatives" / "lcmv")
 
62
  estimator = LCMVSourceEstimator(config)
63
  metadata = estimator.run_enhanced_computation()
64
 
 
66
  difumo_config = {'dimension': 512, 'resolution_mm': 2}
67
  estimator.run_difumo_extraction(difumo_config=difumo_config)
68
 
69
+ # Dual output logic
70
+ if download_option == "difumo_only":
71
+ difumo_file = subject_output / "difumo_time_courses.npy"
72
+ if difumo_file.exists():
73
+ return str(difumo_file)
74
+ else:
75
+ raise FileNotFoundError("DiFuMo file not found. Ensure 'Extract DiFuMo' is enabled.")
76
+
77
+ else: # full_zip
78
+ zip_path = abs_base / f"{SUBJECT_ID}_{TASK}_lcmv_output.zip"
79
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
80
+ for root, _, files in os.walk(subject_output):
81
+ for file in files:
82
+ file_path = Path(root) / file
83
+ arcname = file_path.relative_to(abs_base)
84
+ zf.write(file_path, arcname)
85
+ return str(zip_path)
86
 
87
  except Exception as e:
88
  error_log = subject_output / "lcmv_error.log"
89
  with open(error_log, "w") as f:
90
  f.write(f"LCMV failed:\n{str(e)}")
91
+ return str(error_log)
 
 
 
92
 
93
  # Gradio UI
94
  with gr.Blocks(theme=gr.themes.Base(), title="LCMV Source Estimation") as demo:
 
101
  run_difumo = gr.Checkbox(True, label="Extract DiFuMo 512 time courses")
102
  reg = gr.Number(0.01, label="LCMV Regularization (reg)")
103
  n_jobs = gr.Number(1, label="Parallel Jobs (n_jobs)", precision=0)
104
+ download_option = gr.Radio(
105
+ choices=["full_zip", "difumo_only"],
106
+ value="full_zip",
107
+ label="Download Option",
108
+ info="• full_zip: All outputs (STC, metadata, plots, DiFuMo)\n• difumo_only: Only difumo_time_courses.npy (for connectivity)"
109
+ )
110
  run_btn = gr.Button("Run LCMV", variant="primary")
111
 
112
  with gr.Column():
113
+ output_file = gr.File(label="Download Output")
114
 
115
  run_btn.click(
116
  fn=run_lcmv,
117
+ inputs=[fif_input, run_difumo, reg, n_jobs, download_option],
118
  outputs=output_file,
119
  )
120
 
121
+ demo.launch()