# AphasiaPred / fc_visualization.py
# Uploaded by SreekarB ("Upload 7 files"), commit 50c6714 (verified).
"""
FC Matrix Visualization Module.
This module provides functionality for visualizing Functional Connectivity matrices
independently from the prediction pipeline.
"""
import numpy as np
# Configure matplotlib for headless environment
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend
import matplotlib.pyplot as plt
from pathlib import Path
import argparse
import os
import nibabel as nib
try:
from nilearn import input_data, connectome
from nilearn.image import load_img
from nilearn import datasets
NILEARN_AVAILABLE = True
except ImportError:
NILEARN_AVAILABLE = False
print("Warning: nilearn not available. Direct fMRI processing disabled.")
from config import PREPROCESS_CONFIG
# Import shared utility function
from visualization import vector_to_matrix
class FCVisualizer:
    """Plot Functional Connectivity (FC) matrices with shared display settings.

    An instance stores the colormap and colour limits used by every plot it
    produces, and offers helpers that start from an in-memory matrix, a
    ``.npy`` file, or an fMRI NIfTI image.
    """

    def __init__(self, cmap='RdBu_r', vmin=-1, vmax=1):
        """
        Initialize FCVisualizer with display parameters.

        Args:
            cmap: Colormap to use for FC matrices
            vmin: Minimum value for color scaling
            vmax: Maximum value for color scaling
        """
        self.cmap = cmap
        self.vmin = vmin
        self.vmax = vmax

    def plot_single_matrix(self, matrix, title="FC Matrix", ax=None, fig=None):
        """
        Plot a single FC matrix with a colorbar.

        Args:
            matrix: 2D array-like containing the FC matrix
            title: Title for the plot
            ax: Matplotlib axis to plot on (optional). When omitted, a new
                8x6-inch figure and axis are created.
            fig: Matplotlib figure (optional). When ``ax`` is given but
                ``fig`` is not, the figure that owns ``ax`` is used.

        Returns:
            fig, ax: The figure and axis objects (``fig`` is never None)
        """
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 6))
        elif fig is None:
            # Callers that pass only an axis still get a usable Figure back.
            fig = ax.figure
        im = ax.imshow(matrix, cmap=self.cmap, vmin=self.vmin, vmax=self.vmax)
        ax.set_title(title)
        # Attach the colorbar to the figure that owns `ax`, rather than to
        # whatever pyplot currently considers the "current" figure.
        ax.figure.colorbar(im, ax=ax)
        return fig, ax

    @staticmethod
    def _resolve_titles(n_matrices, titles):
        """Return exactly ``n_matrices`` titles, filling gaps with defaults.

        Extra titles are dropped; missing ones become "FC Matrix <k>".
        """
        defaults = [f"FC Matrix {i+1}" for i in range(n_matrices)]
        if titles is None:
            return defaults
        given = list(titles)[:n_matrices]
        return given + defaults[len(given):]

    def plot_matrix_comparison(self, matrices, titles=None, figsize=None):
        """
        Plot multiple FC matrices side by side for comparison.

        Args:
            matrices: Sequence of 2D arrays containing FC matrices
            titles: Titles for each matrix (optional). If fewer titles than
                matrices are given, the remaining panels get default titles.
            figsize: Custom figure size (optional); defaults to 5 inches per
                matrix wide and 5 inches tall.

        Returns:
            fig: The figure object

        Raises:
            ValueError: If ``matrices`` is empty.
        """
        matrices = list(matrices)
        n_matrices = len(matrices)
        if n_matrices == 0:
            raise ValueError("plot_matrix_comparison requires at least one matrix")
        if figsize is None:
            figsize = (5 * n_matrices, 5)
        titles = self._resolve_titles(n_matrices, titles)

        # squeeze=False always returns a 2D array of axes, so the
        # single-matrix case needs no special handling.
        fig, axes = plt.subplots(1, n_matrices, figsize=figsize, squeeze=False)
        for ax, matrix, title in zip(axes[0], matrices, titles):
            im = ax.imshow(matrix, cmap=self.cmap, vmin=self.vmin, vmax=self.vmax)
            ax.set_title(title)
            fig.colorbar(im, ax=ax)
        fig.tight_layout()
        return fig

    def load_and_visualize_npy(self, file_path):
        """
        Load and visualize an FC matrix from a .npy file.

        A 1D array is treated as upper-triangular values and expanded to a
        full matrix; a 2D array is plotted as-is.

        Args:
            file_path: Path to the .npy file

        Returns:
            fig: The figure object containing the visualization

        Raises:
            ValueError: If the stored array is neither 1D nor 2D.
        """
        data = np.load(file_path)
        if data.ndim == 1:
            matrix = self._triu_to_matrix(data)
        elif data.ndim == 2:
            matrix = data
        else:
            raise ValueError(
                f"Expected a 1D upper-triangle vector or a 2D matrix in "
                f"{file_path}, got an array with shape {data.shape}"
            )

        filename = os.path.basename(file_path)
        title = f"FC Matrix: {filename}"
        fig, _ = self.plot_single_matrix(matrix, title=title)
        return fig

    def _triu_to_matrix(self, triu_values, fisher_z=True):
        """
        Convert upper triangular values to a full FC matrix.

        Delegates to ``vector_to_matrix`` from ``visualization.py``.

        Args:
            triu_values: 1D array of upper triangular values
            fisher_z: Whether values are Fisher z-transformed. Currently
                ignored: it is not forwarded to ``vector_to_matrix``.

        Returns:
            full_matrix: 2D matrix built by ``vector_to_matrix`` (expected
                to be symmetric)
        """
        # NOTE(review): values from _process_single_fmri are Fisher
        # z-transformed (|z| up to ~2.65) while the default colour range is
        # [-1, 1]; confirm whether vector_to_matrix inverts the transform or
        # whether such plots saturate.
        return vector_to_matrix(triu_values)

    def process_and_visualize_fmri(self, fmri_file):
        """
        Process an fMRI file and visualize its FC matrix.

        Args:
            fmri_file: Path to the fMRI .nii or .nii.gz file

        Returns:
            fig: The figure object containing the visualization,
                or None if nilearn is missing or processing fails (the
                error is printed).
        """
        if not NILEARN_AVAILABLE:
            print("Error: nilearn is required for fMRI processing")
            return None
        try:
            fc_triu = self._process_single_fmri(fmri_file)
            fc_matrix = self._triu_to_matrix(fc_triu)
            filename = os.path.basename(fmri_file)
            title = f"FC Matrix: {filename}"
            fig, _ = self.plot_single_matrix(fc_matrix, title=title)
            return fig
        except Exception as e:
            # Deliberate best-effort boundary: report and signal failure with
            # None so the CLI can print its own error message.
            print(f"Error processing fMRI file: {e}")
            return None

    def _process_single_fmri(self, fmri_file):
        """
        Process a single fMRI file into an FC feature vector.

        Steps: extract time series from spheres centred on the Power 2011
        coordinates, correlate every pair of ROI time series, keep the strict
        upper triangle, and Fisher z-transform the result. Sphere radius,
        band-pass limits and TR come from ``PREPROCESS_CONFIG`` (keys
        'radius', 'low_pass', 'high_pass', 't_r').

        Args:
            fmri_file: Path to the fMRI .nii or .nii.gz file

        Returns:
            fc_triu: 1D array of upper triangular values (Fisher z-transformed)
        """
        print(f"Processing fMRI file: {fmri_file}")

        power = datasets.fetch_coords_power_2011()
        coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T

        masker = input_data.NiftiSpheresMasker(
            coords,
            radius=PREPROCESS_CONFIG['radius'],
            standardize=True,
            memory='nilearn_cache',
            memory_level=1,
            verbose=0,
            detrend=True,
            low_pass=PREPROCESS_CONFIG['low_pass'],
            high_pass=PREPROCESS_CONFIG['high_pass'],
            t_r=PREPROCESS_CONFIG['t_r']
        )

        print("Loading NIfTI file...")
        fmri_img = load_img(fmri_file)
        print(f"NIfTI file loaded, shape: {fmri_img.shape}")

        print("Extracting time series...")
        time_series = masker.fit_transform(fmri_img)
        print(f"Time series extracted, shape: {time_series.shape}")

        print("Computing FC matrix...")
        correlation_measure = connectome.ConnectivityMeasure(
            kind='correlation',
            vectorize=False,
            discard_diagonal=False
        )
        fc_matrix = correlation_measure.fit_transform([time_series])[0]
        print(f"FC matrix computed, shape: {fc_matrix.shape}")

        # k=1 excludes the diagonal (self-correlations of 1).
        triu_indices = np.triu_indices_from(fc_matrix, k=1)
        fc_triu = fc_matrix[triu_indices]

        # Clip before arctanh so |r| -> 1 does not produce infinite z-values
        # (arctanh(0.99) is about 2.65).
        fc_triu = np.arctanh(np.clip(fc_triu, -0.99, 0.99))
        print(f"Processing complete. FC features shape: {fc_triu.shape}")
        return fc_triu
def create_synthetic_fc_matrix(seed=None, n_rois=264, n_features=50):
    """
    Create a synthetic FC matrix for demonstration purposes.

    The matrix is the Pearson correlation between ``n_rois`` random Gaussian
    signals of length ``n_features``. It is therefore symmetric, has values
    in [-1, 1], and has a diagonal of exactly 1.0.

    Args:
        seed: Random seed for reproducibility. When given, a private
            ``np.random.RandomState`` is used, so NumPy's global RNG is left
            untouched; the values are identical to those obtained by seeding
            the global RNG with the same seed. When None, NumPy's global RNG
            is used.
        n_rois: Number of regions (default 264, the Power atlas size)
        n_features: Number of samples per synthetic signal (default 50)

    Returns:
        matrix: (n_rois, n_rois) NumPy array representing FC
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    signals = rng.randn(n_rois, n_features)
    # corrcoef returns a 0-d array for a single row; keep the result 2D.
    matrix = np.atleast_2d(np.corrcoef(signals))
    # corrcoef's diagonal can deviate from 1.0 by rounding; pin it exactly.
    np.fill_diagonal(matrix, 1.0)
    return matrix
def main():
    """Command-line interface for FC matrix visualization.

    Builds a figure from ``--synthetic``, from ``--input`` (a ``.npy`` FC
    matrix or a ``.nii``/``.nii.gz`` fMRI image), or from a demo matrix when
    neither is given. The figure is then saved to ``--output`` or shown. When
    no output path is given and matplotlib's backend cannot display windows
    (this module selects 'Agg' at import time), the figure is saved to
    ``fc_matrix.png`` instead of being silently discarded.

    Returns:
        int: Exit status -- 0 on success, 1 on any error.
    """
    parser = argparse.ArgumentParser(description='Visualize FC matrices')
    parser.add_argument('--input', type=str, help='Input file (fMRI .nii/.nii.gz or .npy FC matrix)')
    parser.add_argument('--output', type=str, help='Output image file (PNG/JPG/PDF)')
    parser.add_argument('--cmap', type=str, default='RdBu_r', help='Colormap (default: RdBu_r)')
    parser.add_argument('--vmin', type=float, default=-1, help='Minimum value for colormap')
    parser.add_argument('--vmax', type=float, default=1, help='Maximum value for colormap')
    parser.add_argument('--synthetic', action='store_true', help='Generate a synthetic FC matrix')
    parser.add_argument('--seed', type=int, default=42, help='Random seed for synthetic data')
    args = parser.parse_args()

    visualizer = FCVisualizer(cmap=args.cmap, vmin=args.vmin, vmax=args.vmax)

    fig = None
    if args.synthetic:
        matrix = create_synthetic_fc_matrix(seed=args.seed)
        fig, _ = visualizer.plot_single_matrix(matrix, title="Synthetic FC Matrix")
    elif args.input:
        input_path = Path(args.input)
        if not input_path.exists():
            print(f"Error: Input file not found: {args.input}")
            return 1

        # Match on the full, case-insensitive file name so only '.nii.gz'
        # (not every '.gz' archive) is treated as an fMRI image.
        name = input_path.name.lower()
        if name.endswith('.npy'):
            try:
                fig = visualizer.load_and_visualize_npy(input_path)
            except (OSError, ValueError) as err:
                print(f"Error: Could not load FC matrix from {args.input}: {err}")
                return 1
        elif name.endswith(('.nii', '.nii.gz')):
            if not NILEARN_AVAILABLE:
                print("Error: nilearn is required for processing fMRI files")
                return 1
            fig = visualizer.process_and_visualize_fmri(input_path)
        else:
            print(f"Error: Unsupported file format: {input_path.suffix}")
            print("Supported formats: .npy (FC matrix), .nii/.nii.gz (fMRI)")
            return 1
    else:
        print("No input file or --synthetic flag provided. Generating a demo matrix.")
        matrix = create_synthetic_fc_matrix(seed=args.seed)
        fig, _ = visualizer.plot_single_matrix(matrix, title="Demo FC Matrix")

    if fig is None:
        print("Error: Failed to create visualization")
        return 1

    output = args.output
    backend = matplotlib.get_backend().lower()
    non_interactive_backends = {'agg', 'cairo', 'pdf', 'pgf', 'ps', 'svg', 'template'}
    if not output and backend in non_interactive_backends:
        # plt.show() is a no-op on these backends, so nothing would be shown.
        output = 'fc_matrix.png'
        print(f"No --output given and the '{backend}' matplotlib backend cannot "
              f"display figures; saving to {output} instead.")

    if output:
        fig.savefig(output, dpi=300, bbox_inches='tight')
        print(f"Visualization saved to {output}")
        plt.close(fig)
    else:
        plt.show()
        print("Visualization displayed. Close the window to exit.")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status so shell callers
    # can detect failures; SystemExit(None) still exits with status 0.
    raise SystemExit(main())