|
|
|
|
|
|
|
|
|
|
|
import mne |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import nibabel as nib |
|
|
from tqdm import tqdm |
|
|
import os, time, pickle |
|
|
from pathlib import Path |
|
|
import matplotlib.pyplot as plt |
|
|
from nilearn import datasets, image |
|
|
|
|
|
|
|
|
# Keep MNE quiet: only warnings and errors are logged; the pipeline prints
# its own progress messages.
mne.set_log_level('warning')
|
|
|
|
|
class LCMVSourceEstimator:
    """LCMV beamformer source estimation on the fsaverage template.

    The class runs two stages that communicate through files written to
    ``subject_output``:

    * ``run_enhanced_computation`` - montage, coregistration, forward model,
      LCMV filters and the resulting source estimate.
    * ``run_difumo_extraction`` - reloads the saved source estimate and
      reduces it to one weighted time course per DiFuMo component.
    """

    def __init__(self, config):
        """
        Initialize the LCMV Source Estimator with configuration.

        Parameters:
            config (dict): Configuration dictionary. ``__init__`` reads
                'project_base', 'subject_id' and 'task';
                ``run_enhanced_computation`` additionally reads
                'ica_file_path', 'gpsc_file_path', 'n_jobs' and 'reg'.

        Side effects:
            Creates ``<project_base>/derivatives/lcmv/<subject_id>_<task>``
            if it does not exist yet.
        """
        self.config = config
        self.project_base = Path(config['project_base'])
        self.subject_id = config['subject_id']
        self.task = config['task']

        # Shared by all subjects: fsaverage download and volume source space.
        self.global_subjects_dir = self.project_base / 'derivatives/lcmv'

        # Per subject/task outputs live next to the shared files.
        self.subject_output = self.global_subjects_dir / f'{self.subject_id}_{self.task}'
        self.subject_output.mkdir(parents=True, exist_ok=True)

    def parse_gpsc(self, filepath):
        """Parse a .gpsc sensor-position file into ``(name, x, y, z)`` tuples.

        A line contributes an entry when it has at least four
        whitespace-separated fields and fields 2-4 parse as floats; every
        other line (headers, blank lines, malformed rows) is skipped.
        Coordinates are returned exactly as written in the file - centring
        and unit conversion are done by the caller.

        Parameters:
            filepath (str | Path): Path of the .gpsc file.

        Returns:
            list[tuple[str, float, float, float]]: Entries in file order.
        """
        channels = []
        with open(filepath, 'r') as file:
            for line in file:
                parts = line.split()
                if len(parts) < 4:
                    continue
                try:
                    x, y, z = map(float, parts[1:4])
                except ValueError:
                    # Non-numeric coordinate fields: not a sensor row.
                    continue
                channels.append((parts[0], x, y, z))
        return channels

    def run_enhanced_computation(self):
        """Run the complete enhanced LCMV pipeline with improved coregistration.

        Steps, in order: load the raw recording, rename channels, build a
        mean-centred montage from the .gpsc file, make sure an average
        reference projection is applied, coregister to fsaverage
        (fiducials, then ICP with outlier removal), build/load a 5 mm volume
        source space, compute the forward solution, compute LCMV filters
        from a whole-recording covariance, apply them to the continuous
        data and save the source estimate plus bookkeeping files.

        Files written to ``subject_output``: 'fsaverage-trans.fif',
        'fsaverage-vol-eeg-fwd.fif', 'source_estimate_LCMV.h5',
        'source_space_points_mm.npy', 'debug_source_info.pkl' and
        'computation_metadata.pkl'. The source space is cached in
        ``global_subjects_dir`` as 'fsaverage-vol-5mm-src.fif'.

        Returns:
            dict: The metadata dictionary that is also pickled to
            'computation_metadata.pkl'.

        Raises:
            FileNotFoundError: The ICA-cleaned raw file or the .gpsc file
                is missing.
            ValueError: The .gpsc file yields no channels or lacks one of
                the fiducials 'FidNz', 'FidT9', 'FidT10'.
            RuntimeError: Any step of the coregistration failed (the
                original exception is chained as ``__cause__``).
        """
        print("=" * 60)
        print(f"🎯 ENHANCED LCMV SOURCE ESTIMATION - Subject: {self.subject_id}")
        print("=" * 60)

        # ------------------------------------------------------------------
        # Load data
        # ------------------------------------------------------------------
        print("\n=== Loading Data ===")
        ica_file = self.project_base / self.config['ica_file_path']
        gpsc_file = self.project_base / self.config['gpsc_file_path']

        if not ica_file.exists():
            raise FileNotFoundError(f"ICA file not found: {ica_file}")
        if not gpsc_file.exists():
            raise FileNotFoundError(f"GPSC file not found: {gpsc_file}")

        raw = mne.io.read_raw_fif(ica_file, preload=True)
        sfreq = raw.info['sfreq']
        duration_min = raw.n_times / sfreq / 60
        print(f"Data: {duration_min:.1f}min, {sfreq}Hz, {raw.n_times} samples")

        # ------------------------------------------------------------------
        # Channel renaming: '1'..'280' -> 'E1'..'E280', 'REF CZ' -> 'Cz'
        # ------------------------------------------------------------------
        print("\n=== Enhanced Preprocessing Pipeline ===")

        channel_map = {str(i): f'E{i}' for i in range(1, 281)}
        channel_map['REF CZ'] = 'Cz'

        # Only rename channels that are actually present in the recording.
        existing_channels = set(raw.info['ch_names'])
        valid_channel_map = {
            old_name: new_name
            for old_name, new_name in channel_map.items()
            if old_name in existing_channels
        }

        if valid_channel_map:
            raw.rename_channels(valid_channel_map)
            print(f"Renamed {len(valid_channel_map)} channels")

        # ------------------------------------------------------------------
        # Montage from the .gpsc file, centred on the mean sensor position
        # ------------------------------------------------------------------
        print("\n=== Creating Enhanced Montage with Coordinate Normalization ===")

        channels = self.parse_gpsc(gpsc_file)
        if not channels:
            raise ValueError("No valid channels found in .gpsc file")

        gpsc_array = np.array([ch[1:4] for ch in channels])
        mean_pos = np.mean(gpsc_array, axis=0)
        print(f"Original mean position (mm): {mean_pos}")

        # Subtract the mean position, then convert mm -> m (MNE works in metres).
        centered_m = (gpsc_array - mean_pos) / 1000.0
        ch_pos = {ch[0]: pos for ch, pos in zip(channels, centered_m)}

        required_fids = ['FidNz', 'FidT9', 'FidT10']
        missing = [fid for fid in required_fids if fid not in ch_pos]
        if missing:
            raise ValueError(f"Missing fiducials: {missing}")

        # NOTE(review): positions are mean-centred digitizer coordinates but
        # are declared as coord_frame='head'; confirm this is intended rather
        # than letting MNE derive the head frame from the fiducials.
        montage = mne.channels.make_dig_montage(
            ch_pos=ch_pos,
            nasion=ch_pos['FidNz'],
            lpa=ch_pos['FidT9'],
            rpa=ch_pos['FidT10'],
            coord_frame='head'
        )

        raw.set_montage(montage, on_missing='warn')
        raw = raw.pick(['eeg', 'stim'], exclude=raw.info['bads'])

        # ------------------------------------------------------------------
        # Average reference (needed for EEG inverse modelling)
        # ------------------------------------------------------------------
        print("\nChecking EEG reference status...")
        print(f"custom_ref_applied: {raw.info['custom_ref_applied']}")
        print(f"n_projs: {len(raw.info['projs'])}")
        print(f"proj_applied: {raw.proj}")

        if not any(p['desc'] == 'average' for p in raw.info['projs']):
            print("No average reference projection found. Applying it...")
            raw.set_eeg_reference('average', projection=True)
        else:
            print("✅ Average reference projection already in place.")

        if not raw.proj:
            print("🎯 Applying EEG average reference projection...")
            raw.apply_proj()
        else:
            print("Projections already applied.")

        print("Enhanced preprocessing complete (reference now valid for inverse modeling)")

        print("Enhanced montage applied:")
        print(f"FidNz (nasion): {ch_pos['FidNz']}")
        print(f"FidT9 (lpa): {ch_pos['FidT9']}")
        print(f"FidT10 (rpa): {ch_pos['FidT10']}")

        # ------------------------------------------------------------------
        # fsaverage template files
        # ------------------------------------------------------------------
        print("\n=== Source Space Setup ===")
        subject = 'fsaverage'

        bem_file = self.global_subjects_dir / 'fsaverage' / 'bem' / 'fsaverage-5120-5120-5120-bem-sol.fif'
        bem_head = self.global_subjects_dir / 'fsaverage' / 'bem' / 'fsaverage-head-dense.fif'
        src_file = self.global_subjects_dir / 'fsaverage-vol-5mm-src.fif'

        if not bem_file.exists() or not bem_head.exists():
            print("Downloading fsaverage to GLOBAL directory...")
            mne.datasets.fetch_fsaverage(subjects_dir=self.global_subjects_dir, verbose=False)

        # ------------------------------------------------------------------
        # Coregistration: fiducials -> coarse ICP -> outlier removal -> fine ICP
        # ------------------------------------------------------------------
        print("\n=== Running Enhanced Coregistration ===")
        trans_file = self.subject_output / 'fsaverage-trans.fif'

        try:
            # NOTE(review): MNE documents `fiducials` as the fiducials in MRI
            # (surface RAS) coordinates; here the digitized, mean-centred
            # sensor-space fiducials are passed. Confirm this is intended
            # (vs. 'auto'/'estimated' fsaverage fiducials).
            coreg = mne.coreg.Coregistration(
                raw.info,
                subject=subject,
                subjects_dir=self.global_subjects_dir,
                fiducials={
                    'nasion': ch_pos['FidNz'],
                    'lpa': ch_pos['FidT9'],
                    'rpa': ch_pos['FidT10']
                }
            )

            print("1/3: Fitting with fiducials...")
            coreg.fit_fiducials(verbose=False)

            print("2/3: Using EEG channels as head shape points for ICP...")
            coreg.fit_icp(n_iterations=6, nasion_weight=2.0, verbose=False)

            print(" Removing outlier points...")
            dists = coreg.compute_dig_mri_distances()
            n_excluded = np.sum(dists > 5.0 / 1000)  # distances are in metres

            if n_excluded > 0:
                print(f" Excluding {n_excluded} outlier points (distance > 5mm)")
                coreg.omit_head_shape_points(distance=5.0 / 1000)
            else:
                print(" No outlier points to exclude")

            print("3/3: Final ICP refinement...")
            coreg.fit_icp(n_iterations=20, nasion_weight=10.0, verbose=False)

            trans = coreg.trans
            mne.write_trans(trans_file, trans, overwrite=True)
            print(f"Enhanced coregistration successful: {trans_file}")

            # Residual digitizer-to-scalp distances, reported in mm.
            dists = coreg.compute_dig_mri_distances() * 1000
            mean_err = np.mean(dists)
            median_err = np.median(dists)
            max_err = np.max(dists)

            print("\nCoregistration Error (mm):")
            print(f"Mean: {mean_err:.2f}, Median: {median_err:.2f}, Max: {max_err:.2f}")

            if mean_err > 5.0:
                print(f"⚠️ WARNING: Mean error = {mean_err:.2f}mm > 5mm")
            else:
                print("✅ Enhanced coregistration error acceptable")

        except Exception as e:
            # No fallback transform exists, so the pipeline cannot continue.
            print(f"Coregistration failed irrecoverably: {e}")
            raise RuntimeError(f"Coregistration failed: {e}") from e

        # ------------------------------------------------------------------
        # Volume source space (cached in the shared directory)
        # ------------------------------------------------------------------
        print("\n=== Creating Source Space ===")
        if not src_file.exists():
            print("Creating volume source space...")

            src = mne.setup_volume_source_space(
                subject=subject,
                subjects_dir=self.global_subjects_dir,
                pos=5.0,
                mri='T1.mgz',
                add_interpolator=True
            )
            src.save(src_file, overwrite=True)
        else:
            src = mne.read_source_spaces(src_file)

        print(f"Source space: {len(src[0]['vertno'])} active sources out of {src[0]['np']} total points")

        # ------------------------------------------------------------------
        # Forward solution
        # ------------------------------------------------------------------
        print("\n=== Creating Forward Solution ===")
        fwd_file = self.subject_output / 'fsaverage-vol-eeg-fwd.fif'
        bem = mne.read_bem_solution(bem_file)
        fwd = mne.make_forward_solution(
            raw.info, trans=trans, src=src, bem=bem, eeg=True, mindist=5.0, n_jobs=self.config['n_jobs']
        )
        mne.write_forward_solution(fwd_file, fwd, overwrite=True)
        print("Enhanced source space setup complete")

        # ------------------------------------------------------------------
        # LCMV beamformer
        # ------------------------------------------------------------------
        print("\n=== LCMV Beamformer ===")

        print("Computing single covariance from entire recording...")

        cov = mne.compute_raw_covariance(
            raw,
            method='oas',
            picks='eeg',
            rank='info',
            n_jobs=self.config['n_jobs'],
            verbose=False)

        print("Creating LCMV spatial filters...")
        # NOTE(review): the same covariance is used as data_cov and noise_cov,
        # i.e. the data are whitened with their own covariance - confirm this
        # is intended (noise_cov is optional for a single channel type).
        filters = mne.beamformer.make_lcmv(
            info=raw.info,
            forward=fwd,
            data_cov=cov,
            noise_cov=cov,
            reg=self.config['reg'],
            pick_ori='max-power',
            weight_norm='unit-noise-gain',
            reduce_rank=True,
            rank='info',
            verbose=True
        )

        print("Applying LCMV filters to continuous data...")
        stc = mne.beamformer.apply_lcmv_raw(raw=raw, filters=filters)

        print("Saving STC in H5 format (required for complex data)...")
        stc_file = self.subject_output / 'source_estimate_LCMV.h5'
        stc.save(stc_file, ftype='h5', overwrite=True)
        print(f"STC saved successfully in H5 format: {stc_file}")

        print(f"LCMV complete: {stc.data.shape} (sources x timepoints)")
        print(f"STC file saved as: {stc_file}")

        # ------------------------------------------------------------------
        # Source positions matching the rows of stc.data
        # ------------------------------------------------------------------
        print("\n=== Saving source space information ===")

        vertices = stc.vertices[0]
        active_indices = src[0]['vertno']

        # Index the full grid by the STC vertex numbers, then m -> mm.
        src_points_mm = src[0]['rr'][vertices] * 1000

        print(f"STC data shape: {stc.data.shape}")
        print(f"Source points shape: {src_points_mm.shape}")

        if src_points_mm.shape[0] != stc.data.shape[0]:
            print("WARNING: Shape mismatch detected!")
            n_sources = min(src_points_mm.shape[0], stc.data.shape[0])
            src_points_mm = src_points_mm[:n_sources]
            print(f"Using first {n_sources} source points to match STC data")

        # Save after the consistency check so the file on disk is exactly the
        # array reported below.
        np.save(self.subject_output / 'source_space_points_mm.npy', src_points_mm)

        print(f"Source space points saved: {src_points_mm.shape} points")
        print(f" Matches STC data shape: {stc.data.shape[0]} sources")

        # ------------------------------------------------------------------
        # Debug info and metadata
        # ------------------------------------------------------------------
        print("\n=== Saving debug info and metadata ===")

        debug_info = {
            'src_vertno': active_indices.tolist(),
            'stc_vertices': vertices.tolist(),
            'src_np': src[0]['np'],
            'n_active_sources': len(active_indices),
            'n_stc_vertices': len(vertices),
            # Always defined here: a failed coregistration raises above.
            'coregistration_error_mm': {
                'mean': mean_err,
                'median': median_err,
                'max': max_err
            }
        }
        with open(self.subject_output / 'debug_source_info.pkl', 'wb') as f:
            pickle.dump(debug_info, f)

        metadata = {
            'stc_shape': stc.data.shape,
            'n_source_points': len(vertices),
            'source_space_indices': vertices.tolist(),
            'sfreq': sfreq,
            'duration_min': duration_min,
            'stc_file': str(stc_file),
            'src_file': str(src_file),
            'subject_output': str(self.subject_output),
            'global_subjects_dir': str(self.global_subjects_dir),
            'enhanced_coregistration': True,
            'coordinate_normalization': 'mean_centered',
            'fiducials': {
                'FidNz': ch_pos['FidNz'].tolist(),
                'FidT9': ch_pos['FidT9'].tolist(),
                'FidT10': ch_pos['FidT10'].tolist()
            }
        }
        with open(self.subject_output / 'computation_metadata.pkl', 'wb') as f:
            pickle.dump(metadata, f)

        print("Enhanced computation complete and metadata saved")
        print("\nENHANCED LCMV SOURCE ESTIMATION COMPLETE!")
        print(" - Enhanced coregistration with error checking")
        print(" - Proper coordinate normalization")
        print(" - All original outputs maintained")
        print(f" - Results saved to: {self.subject_output}")

        return metadata

    def extract_difumo_time_courses(self, stc, src, config, subject_output):
        """Extract weighted time courses from DiFuMo atlas.

        Every source of ``stc`` is mapped to the atlas voxel it falls into.
        For each atlas component, the sources whose voxel weight exceeds
        1e-6 are averaged with those weights (normalised to sum to one).
        Components without any such source get an all-zero time course.

        Parameters:
            stc: Source estimate; ``stc.vertices[0]`` and ``stc.data`` are
                read. The object is not modified.
            src: Source spaces; ``src[0]['rr']`` and ``src[0]['mri_ras_t']``
                are read.
            config (dict): Needs 'dimension' and 'resolution_mm', which are
                forwarded to ``datasets.fetch_atlas_difumo``.
            subject_output (str | Path): Directory receiving
                'difumo_time_courses.npy' and 'difumo_component_info.csv'.

        Returns:
            tuple: ``(time_courses, component_info)`` where ``time_courses``
            is an array with one row per component and ``component_info``
            is a list of dicts with keys 'component', 'n_sources',
            'max_weight' and 'mean_weight'.

        Raises:
            ValueError: ``src[0]`` has no 'mri_ras_t' transform.
        """
        print("\n=== DiFuMo Processing ===")
        atlas = datasets.fetch_atlas_difumo(
            dimension=config['dimension'],
            resolution_mm=config['resolution_mm']
        )
        atlas_img = nib.load(atlas.maps)
        atlas_shape = atlas_img.shape
        n_components = atlas_shape[3]

        # Source positions in metres, as stored by MNE.
        vertices = stc.vertices[0]
        src_rr_m = src[0]['rr'][vertices]

        try:
            mri_ras_t = src[0]['mri_ras_t']['trans']
        except KeyError as err:
            raise ValueError(
                "Source space missing 'mri_ras_t' transform. Ensure it's a proper volume source space."
            ) from err

        # MNE transforms carry their translation in metres, so apply the
        # affine to metre coordinates first and only then convert to mm
        # (scaling the points beforehand would mis-scale the translation).
        # NOTE(review): 'mri_ras_t' yields scanner RAS, which is treated as
        # MNI here - presumably acceptable for fsaverage; confirm.
        ras_m = image.coord_transform(src_rr_m[:, 0], src_rr_m[:, 1], src_rr_m[:, 2], mri_ras_t)
        src_coords_mni = np.array(ras_m).T * 1000

        # World (mm) -> atlas voxel indices via the inverse atlas affine.
        homog = np.column_stack([src_coords_mni, np.ones(len(src_coords_mni))])
        vox_coords = (np.linalg.inv(atlas_img.affine) @ homog.T).T[:, :3]
        vox_coords = np.round(vox_coords).astype(int)

        # Keep only sources that land inside the atlas volume.
        valid_mask = (
            (vox_coords >= 0).all(axis=1) &
            (vox_coords[:, 0] < atlas_shape[0]) &
            (vox_coords[:, 1] < atlas_shape[1]) &
            (vox_coords[:, 2] < atlas_shape[2])
        )
        valid_indices = np.where(valid_mask)[0]
        valid_voxels = vox_coords[valid_mask]
        vox_x, vox_y, vox_z = valid_voxels.T

        print(f"Using {len(valid_indices)}/{len(vertices)} sources within atlas bounds")

        time_courses = []
        component_info = []
        threshold = 1e-6

        # Work on a local array so the caller's STC is left untouched.
        data = stc.data
        if np.iscomplexobj(data):
            print("⚠️ STC is complex - taking absolute value for DiFuMo")
            data = np.abs(data)
        n_times = data.shape[1]

        for comp_idx in range(n_components):
            if comp_idx % 100 == 0:
                print(f"Processing component {comp_idx + 1}/{n_components}")

            try:
                # Load one component at a time to keep memory bounded.
                comp_map = atlas_img.slicer[..., comp_idx].get_fdata()

                # Atlas weight at every valid source, in one indexing step.
                probs = comp_map[vox_x, vox_y, vox_z]
                keep = probs > threshold

                if keep.any():
                    stc_indices = valid_indices[keep]
                    weights = probs[keep]
                    weights = weights / weights.sum()
                    tc = np.average(data[stc_indices], axis=0, weights=weights)
                    info = {
                        'component': comp_idx,
                        'n_sources': len(stc_indices),
                        'max_weight': weights.max(),
                        'mean_weight': weights.mean()
                    }
                else:
                    tc = np.zeros(n_times)
                    info = {
                        'component': comp_idx,
                        'n_sources': 0,
                        'max_weight': 0.0,
                        'mean_weight': 0.0
                    }

                time_courses.append(tc)
                component_info.append(info)

            except Exception as e:
                # Best effort: a failing component is reported and stored as
                # zeros so the output keeps one row per component.
                print(f"Error in component {comp_idx}: {e}")
                time_courses.append(np.zeros(n_times))
                component_info.append({
                    'component': comp_idx, 'n_sources': 0, 'max_weight': 0.0, 'mean_weight': 0.0
                })

        valid_comps = sum(1 for info in component_info if info['n_sources'] > 0)
        print(f"✅ {valid_comps}/{n_components} components have at least one source")

        time_courses = np.array(time_courses)
        subject_output = Path(subject_output)
        np.save(subject_output / 'difumo_time_courses.npy', time_courses)
        pd.DataFrame(component_info).to_csv(subject_output / 'difumo_component_info.csv', index=False)
        print(f"💾 Saved to: {subject_output}")

        return time_courses, component_info

    def run_difumo_extraction(self, difumo_config=None):
        """Run DiFuMo time course extraction on existing data.

        Loads the source estimate previously written to ``subject_output``
        (first match of 'source_estimate_LCMV' + '-vl.stc', '.stc', '.h5')
        and the cached volume source space, then delegates to
        ``extract_difumo_time_courses``.

        Parameters:
            difumo_config (dict | None): Atlas settings with keys
                'dimension' and 'resolution_mm'. Defaults to
                ``{'dimension': 512, 'resolution_mm': 2}``.

        Returns:
            tuple: ``(time_courses, component_info)`` as returned by
            ``extract_difumo_time_courses``.

        Raises:
            FileNotFoundError: No source estimate or no source space file
                was found. Any other exception is printed and re-raised
                unchanged.
        """
        if difumo_config is None:
            difumo_config = {
                'dimension': 512,
                'resolution_mm': 2
            }

        try:
            subject_output = self.subject_output
            stc_base_name = "source_estimate_LCMV"

            # Accept any of the file name variants the STC may be stored under.
            stc_file = None
            for suffix in ['-vl.stc', '.stc', '.h5']:
                candidate = subject_output / f"{stc_base_name}{suffix}"
                if candidate.exists():
                    stc_file = candidate
                    break
            if stc_file is None:
                raise FileNotFoundError(f"STC file not found in {subject_output}")

            print(f"Loading STC: {stc_file}")
            stc = mne.read_source_estimate(stc_file)
            print(f"Loaded STC: {stc.data.shape} (sources × time)")

            src_file = self.global_subjects_dir / "fsaverage-vol-5mm-src.fif"

            print(f"Loading source space: {src_file}")
            if not src_file.exists():
                raise FileNotFoundError(f"Source space not found: {src_file}")
            src = mne.read_source_spaces(src_file)
            print(f"Loaded source space with {len(src[0]['vertno'])} active sources")

            time_courses, component_info = self.extract_difumo_time_courses(
                stc=stc,
                src=src,
                config=difumo_config,
                subject_output=subject_output
            )

            print("\nSUCCESS: DiFuMo time series extraction complete!")
            # Report the actual component count; it depends on difumo_config.
            print(
                f"Output shape: {time_courses.shape} "
                f"({time_courses.shape[0]} components × {time_courses.shape[1]} time points)"
            )
            print(f"Details saved in:\n - {subject_output / 'difumo_time_courses.npy'}\n - {subject_output / 'difumo_component_info.csv'}")

            return time_courses, component_info

        except Exception as e:
            print(f"Error during DiFuMo extraction: {e}")
            raise

    def list_output_files(self):
        """List all files in the output folder.

        Prints one entry per line and returns the same entries.

        Returns:
            list[str]: Names in ``subject_output`` as reported by
            ``os.listdir`` (arbitrary order).
        """
        print(f"\n=== Files in output folder: {self.subject_output} ===")
        # List once so the printed and the returned entries cannot differ.
        files = os.listdir(self.subject_output)
        for file in files:
            print(file)
        return files
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Project locations. GPS_FILE_PATH is relative: LCMVSourceEstimator joins
# 'gpsc_file_path' onto 'project_base'.
# NOTE(review): CROP_BASE_DIR is not used by the code visible in this file -
# presumably consumed by a driver script; confirm before removing.
PROJECT_BASE = "/home/jaizor/jaizor/xtra"
CROP_BASE_DIR = Path(PROJECT_BASE) / "derivatives/ica"
GPS_FILE_PATH = "data/ghw280_from_egig.gpsc"

# Settings shared by every run. LCMVSourceEstimator also reads 'subject_id',
# 'task' and 'ica_file_path', which are not part of this template and must be
# added per recording. 'skip_difumo' is not read by the class itself.
CONFIG_TEMPLATE = {
    'project_base': PROJECT_BASE,
    'gpsc_file_path': GPS_FILE_PATH,
    'reg': 0.01,      # passed to make_lcmv as `reg`
    'n_jobs': -1,     # passed to MNE as `n_jobs`
    'skip_difumo': False
}
|
|
|
|
|
|