# brain-analysis / src/streamlit_app.py
# Commit a7252ea (verified) by fadzwan: "Update src/streamlit_app.py"
import os
import io
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne
from mne.channels import make_dig_montage
from matplotlib.animation import FuncAnimation
import requests
import streamlit as st
import tempfile
from scipy.signal import welch
# Redirect Matplotlib's config dir and Streamlit's home dir to /tmp, presumably
# because the default ($HOME-based) locations are not writable on the hosting
# platform -- confirm.
# NOTE(review): matplotlib is already imported above; it may have resolved its
# config directory at import time, in which case MPLCONFIGDIR set here has no
# effect. Consider moving these assignments above the imports.
os.environ.update(
    MPLCONFIGDIR="/tmp/matplotlib",
    STREAMLIT_HOME="/tmp/streamlit",
)

# Browser-tab title and centred page layout for the app.
st.set_page_config(layout="centered", page_title="Muse EEG Topomap Viewer")
def plot_eeg_topomap_muse_from_csv(csv_url, save_path_animation=None, save_directory=None,
                                   show_names=False, start_time=0.05, end_time=1, step_size=0.1):
    """Download a Muse CSV recording and render an animated EEG topomap GIF.

    The CSV must contain the four Muse channels ``AF7``, ``AF8``, ``TP9`` and
    ``TP10`` plus a ``TimeStamp`` or ``timestamps`` column. The sampling rate
    is estimated from the mean spacing of that timestamp column.

    Args:
        csv_url: HTTP(S) URL of the CSV file.
        save_path_animation: Output GIF path. When falsy, the data is still
            downloaded and validated but no animation is rendered.
        save_directory: Optional directory. If given it is created, and the GIF
            is written there under ``os.path.basename(save_path_animation)``.
        show_names: Forwarded to ``Evoked.plot_topomap`` to label the sensors.
        start_time: First frame time in seconds.
        end_time: Exclusive upper bound for frame times in seconds.
        step_size: Spacing between frame times in seconds; must be positive.

    Returns:
        The path of the written GIF, or ``None`` when no save path was given.

    Raises:
        RuntimeError: The download did not return HTTP 200.
        ValueError: Required columns are missing, the timestamps do not give a
            positive finite sampling interval, or the frame time range is empty.
    """
    # Bounded wait so an unreachable host cannot hang the Streamlit worker.
    response = requests.get(csv_url, timeout=30)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download CSV from {csv_url}")
    data = pd.read_csv(io.StringIO(response.text))
    st.write(f"Loaded data with shape: {data.shape}")

    muse_channels = ['AF7', 'AF8', 'TP9', 'TP10']
    if not all(ch in data.columns for ch in muse_channels):
        raise ValueError(f"The dataset must contain the following channels: {muse_channels}")
    eeg_data = data[muse_channels].values.T  # one row per channel

    # Approximate head-frame electrode positions for the Muse headband.
    muse_positions = {
        'AF7': [-0.05, 0.085, 0],
        'AF8': [0.05, 0.085, 0],
        'TP9': [-0.08, -0.04, 0],
        'TP10': [0.08, -0.04, 0]
    }
    # Tiny random x/y jitter, presumably to avoid perfectly symmetric layouts
    # in the topomap interpolation -- confirm intent. Uses the global NumPy RNG,
    # so positions differ slightly between calls.
    for ch in muse_positions:
        muse_positions[ch][0] += np.random.normal(0, 0.0001)
        muse_positions[ch][1] += np.random.normal(0, 0.0001)
    montage = make_dig_montage(ch_pos=muse_positions, coord_frame='head')

    if 'TimeStamp' in data.columns:
        timestamps = data['TimeStamp']
    elif 'timestamps' in data.columns:
        timestamps = data['timestamps']
    else:
        raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")
    average_interval = np.mean(np.diff(timestamps))
    # A single-row file gives NaN here; repeated or decreasing timestamps give
    # a zero/negative interval. Either would produce an invalid sfreq.
    if not np.isfinite(average_interval) or average_interval <= 0:
        raise ValueError("Timestamps must be numeric and increasing to estimate the sampling rate")
    sfreq = 1 / average_interval

    info = mne.create_info(ch_names=muse_channels, sfreq=sfreq, ch_types='eeg')
    evoked = mne.EvokedArray(eeg_data, info)
    evoked.set_montage(montage)

    if step_size <= 0:
        raise ValueError("step_size must be positive.")
    times = np.arange(start_time, end_time, step_size)
    if len(times) == 0:
        raise ValueError("Not enough time range selected for topomap animation.")

    if not save_path_animation:
        return None
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        save_path_animation = os.path.join(save_directory, os.path.basename(save_path_animation))

    fig, ax = plt.subplots()

    def update(frame):
        # Redraw the topomap for a single time point on the shared axes.
        ax.clear()
        evoked.plot_topomap([times[frame]], ch_type='eeg', time_unit='s',
                            axes=ax, colorbar=False, show=False,
                            show_names=show_names)

    try:
        anim = FuncAnimation(fig, update, frames=range(len(times)), interval=200)
        anim.save(save_path_animation, writer='pillow')
    finally:
        # Release the figure even if rendering fails, so repeated runs in the
        # long-lived Streamlit process do not accumulate open figures.
        plt.close(fig)
    return save_path_animation
# ---------------------------------------------------------------------------
# Streamlit UI: title, CSV URL input shared by every button below, and the
# topomap-animation button.
# ---------------------------------------------------------------------------
st.title("🧠 Muse EEG Topomap Animation from CSV")
default_url = (
    "https://raw.githubusercontent.com/garenasd945/EEG_muse2/refs/heads/master/"
    "dataset/original_data/subjecta-relaxed-1.csv"
)
csv_url = st.text_input("CSV URL", value=default_url)

generate_button = st.button("Generate Topomap Animation")
if generate_button:
    try:
        # The GIF lives in a throwaway directory; st.image is called while the
        # directory still exists, inside the `with` block.
        with tempfile.TemporaryDirectory() as scratch_dir:
            save_path_animation = os.path.join(scratch_dir, "eeg_topomap_animation.gif")
            topomap_kwargs = dict(
                save_path_animation=save_path_animation,
                save_directory=scratch_dir,
                show_names=True,
            )
            gif_path = plot_eeg_topomap_muse_from_csv(csv_url, **topomap_kwargs)
            if gif_path:
                st.image(gif_path, caption="EEG Topomap Animation")
    except Exception as exc:
        st.error(f"❌ Error: {exc}")
# Button: plot the raw and 1-50 Hz band-pass filtered TP9 waveform.
plot_waveform_button = st.button("Plot TP9 Waveforms")
if plot_waveform_button:
    waveform_figs = []  # every figure created here is closed in `finally`
    try:
        # Download and read CSV; bounded wait so a dead host cannot hang the app.
        response = requests.get(csv_url, timeout=30)
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text))
        st.write(f"Loaded data with shape: {data.shape}")

        # Extract timestamps and TP9, with readable errors instead of a bare KeyError.
        if 'timestamps' in data.columns:
            timestamps = data['timestamps']
        elif 'TimeStamp' in data.columns:
            timestamps = data['TimeStamp']
        else:
            raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")
        if 'TP9' not in data.columns:
            raise ValueError("CSV must contain a 'TP9' column")
        tp9_raw = data['TP9']

        # Sampling rate from the mean inter-sample interval; reject NaN (single
        # row) and non-positive intervals, which would give an invalid rate.
        average_interval = np.mean(np.diff(timestamps))
        if not np.isfinite(average_interval) or average_interval <= 0:
            raise ValueError("Timestamps must be numeric and increasing to estimate the sampling rate")
        sampling_rate = 1 / average_interval

        # Wrap the single channel in an MNE Raw object and band-pass 1-50 Hz.
        info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
        raw = mne.io.RawArray(np.array([tp9_raw]), info)
        raw.filter(1, 50, fir_design='firwin')
        tp9_filtered = raw.get_data()[0]

        # Plot raw waveform
        fig_raw, ax_raw = plt.subplots(figsize=(12, 4))
        waveform_figs.append(fig_raw)
        ax_raw.plot(timestamps, tp9_raw, label='TP9', alpha=0.6, color='blue')
        ax_raw.set_title('Raw Waveform of TP9')
        ax_raw.set_xlabel('Time (seconds)')
        ax_raw.set_ylabel('Amplitude')
        ax_raw.legend()
        ax_raw.grid(True)
        st.pyplot(fig_raw)

        # Plot filtered waveform
        fig_filtered, ax_filtered = plt.subplots(figsize=(12, 4))
        waveform_figs.append(fig_filtered)
        ax_filtered.plot(timestamps, tp9_filtered, label='TP9 (1-50Hz)', linewidth=1.5, color='green')
        ax_filtered.set_title('Filtered Waveform of TP9')
        ax_filtered.set_xlabel('Time (seconds)')
        ax_filtered.set_ylabel('Amplitude')
        ax_filtered.legend()
        ax_filtered.grid(True)
        st.pyplot(fig_filtered)
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
    finally:
        # st.pyplot does not close figures; close them so they do not pile up
        # in pyplot's registry across reruns of this long-lived app.
        for open_fig in waveform_figs:
            plt.close(open_fig)
# Button: spectrograms of TP9 before and after 1-50 Hz band-pass filtering.
spectrogram_button = st.button("Plot TP9 Spectrograms")
if spectrogram_button:
    spectrogram_figs = []  # every figure created here is closed in `finally`
    try:
        # Download and read CSV; bounded wait so a dead host cannot hang the app.
        response = requests.get(csv_url, timeout=30)
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text))
        st.write(f"Loaded data with shape: {data.shape}")

        # Extract timestamps and TP9, with readable errors instead of a bare KeyError.
        if 'timestamps' in data.columns:
            timestamps = data['timestamps']
        elif 'TimeStamp' in data.columns:
            timestamps = data['TimeStamp']
        else:
            raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")
        if 'TP9' not in data.columns:
            raise ValueError("CSV must contain a 'TP9' column")
        tp9_raw = data['TP9']

        # Sampling rate from the mean inter-sample interval; reject NaN (single
        # row) and non-positive intervals, which would give an invalid rate.
        average_interval = np.mean(np.diff(timestamps))
        if not np.isfinite(average_interval) or average_interval <= 0:
            raise ValueError("Timestamps must be numeric and increasing to estimate the sampling rate")
        sampling_rate = 1 / average_interval

        # Wrap the single channel in an MNE Raw object and band-pass 1-50 Hz.
        info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
        raw = mne.io.RawArray(np.array([tp9_raw]), info)
        raw.filter(1, 50, fir_design='firwin')
        tp9_filtered = raw.get_data()[0]

        from scipy.signal import spectrogram

        # Spectrogram before filtering (256-sample segments).
        frequencies_pre, times_pre, spectrogram_pre = spectrogram(tp9_raw, fs=sampling_rate, nperseg=256)
        fig_spec_pre, ax_pre = plt.subplots(figsize=(10, 6))
        spectrogram_figs.append(fig_spec_pre)
        # +1e-10 keeps log10 finite for zero-power bins.
        p1 = ax_pre.pcolormesh(times_pre, frequencies_pre, np.log10(spectrogram_pre + 1e-10), shading='gouraud')
        fig_spec_pre.colorbar(p1, ax=ax_pre, label='Log Power Spectral Density')
        ax_pre.set_title('Spectrogram of TP9 (Raw)')
        ax_pre.set_xlabel('Time (seconds)')
        ax_pre.set_ylabel('Frequency (Hz)')
        st.pyplot(fig_spec_pre)

        # Spectrogram after filtering
        frequencies_post, times_post, spectrogram_post = spectrogram(tp9_filtered, fs=sampling_rate, nperseg=256)
        fig_spec_post, ax_post = plt.subplots(figsize=(10, 6))
        spectrogram_figs.append(fig_spec_post)
        p2 = ax_post.pcolormesh(times_post, frequencies_post, np.log10(spectrogram_post + 1e-10), shading='gouraud')
        fig_spec_post.colorbar(p2, ax=ax_post, label='Log Power Spectral Density')
        ax_post.set_title('Spectrogram of TP9 (Filtered 1–50Hz)')
        ax_post.set_xlabel('Time (seconds)')
        ax_post.set_ylabel('Frequency (Hz)')
        st.pyplot(fig_spec_post)
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
    finally:
        # st.pyplot does not close figures; close them so they do not pile up
        # in pyplot's registry across reruns of this long-lived app.
        for open_fig in spectrogram_figs:
            plt.close(open_fig)
# Button: compute and plot the absolute power of classic EEG bands for the
# 1-50 Hz filtered TP9 signal.
if st.button("Show Band Power Chart (TP9)"):
    fig = None
    try:
        # Download and read CSV; bounded wait so a dead host cannot hang the app.
        response = requests.get(csv_url, timeout=30)
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text))
        st.write(f"Loaded data with shape: {data.shape}")

        # Extract timestamps and TP9, with readable errors instead of a bare KeyError.
        if 'timestamps' in data.columns:
            timestamps = data['timestamps']
        elif 'TimeStamp' in data.columns:
            timestamps = data['TimeStamp']
        else:
            raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")
        if 'TP9' not in data.columns:
            raise ValueError("CSV must contain a 'TP9' column")
        tp9_raw = data['TP9']

        # Sampling rate from the mean inter-sample interval; reject NaN (single
        # row) and non-positive intervals, which would give an invalid rate.
        average_interval = np.mean(np.diff(timestamps))
        if not np.isfinite(average_interval) or average_interval <= 0:
            raise ValueError("Timestamps must be numeric and increasing to estimate the sampling rate")
        sampling_rate = 1 / average_interval

        # Wrap the single channel in an MNE Raw object and band-pass 1-50 Hz.
        info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
        raw = mne.io.RawArray(np.array([tp9_raw]), info)
        raw.filter(1, 50, fir_design='firwin')
        tp9_filtered = raw.get_data()[0]

        # Band edges in Hz. Both edges are inclusive below, so a PSD bin that
        # lands exactly on a shared edge is counted in both neighbouring bands.
        bands = {
            "Delta (1–4 Hz)": (1, 4),
            "Theta (4–8 Hz)": (4, 8),
            "Alpha (8–12 Hz)": (8, 12),
            "Beta (12–30 Hz)": (12, 30),
            "Gamma (30–50 Hz)": (30, 50)
        }

        # numpy.trapz is deprecated in NumPy 2.0; scipy.integrate.trapezoid is
        # the stable equivalent and scipy is already a dependency.
        from scipy.integrate import trapezoid
        from scipy.signal import welch

        # One Welch PSD (scipy's default segment length), integrated per band,
        # instead of recomputing the identical PSD for every band.
        freqs, psd = welch(tp9_filtered, sampling_rate)

        def bandpower(band):
            """Return the PSD integrated over the closed interval `band` (Hz)."""
            low, high = band
            in_band = (freqs >= low) & (freqs <= high)
            return trapezoid(psd[in_band], freqs[in_band])

        powers = [bandpower(b) for b in bands.values()]

        # Plot (list(...) rather than a dict_keys view for matplotlib's categorical axis)
        fig, ax = plt.subplots()
        ax.bar(list(bands), powers, color='skyblue')
        ax.set_title("Band Power - TP9")
        ax.set_ylabel("Power")
        ax.tick_params(axis='x', rotation=45)
        st.pyplot(fig)
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
    finally:
        # st.pyplot does not close figures; release it explicitly.
        if fig is not None:
            plt.close(fig)
# Button: render a scrolling ~5 s window of the 1-50 Hz filtered TP9 trace as
# an animated GIF.
if st.button("Show Animated EEG Trace (TP9)"):
    fig = None
    try:
        # Download and read CSV; bounded wait so a dead host cannot hang the app.
        response = requests.get(csv_url, timeout=30)
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text))
        st.write(f"Loaded data with shape: {data.shape}")

        # Extract timestamps and TP9, with readable errors instead of a bare KeyError.
        if 'timestamps' in data.columns:
            timestamps = data['timestamps']
        elif 'TimeStamp' in data.columns:
            timestamps = data['TimeStamp']
        else:
            raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")
        if 'TP9' not in data.columns:
            raise ValueError("CSV must contain a 'TP9' column")
        tp9_raw = data['TP9']

        # Sampling rate from the mean inter-sample interval; reject NaN (single
        # row) and non-positive intervals, which would give an invalid rate.
        average_interval = np.mean(np.diff(timestamps))
        if not np.isfinite(average_interval) or average_interval <= 0:
            raise ValueError("Timestamps must be numeric and increasing to estimate the sampling rate")
        sampling_rate = 1 / average_interval

        # Wrap the single channel in an MNE Raw object and band-pass 1-50 Hz.
        info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
        raw = mne.io.RawArray(np.array([tp9_raw]), info)
        raw.filter(1, 50, fir_design='firwin')
        tp9_filtered = raw.get_data()[0]

        n_samples = tp9_filtered.size
        window = int(sampling_rate * 5)          # samples per frame (~5 s)
        step = max(1, int(sampling_rate * 0.1))  # hop between frames (~0.1 s), never 0
        # At most 100 frames as before, but stop once the window would run past
        # the end of the recording; always render at least one frame.
        n_frames = max(1, min(100, (n_samples - window) // step + 1))

        # Setup figure and axes
        fig, ax = plt.subplots()
        line, = ax.plot([], [], lw=2)
        ax.set_xlim(0, 5)
        ax.set_ylim(np.min(tp9_filtered), np.max(tp9_filtered))
        ax.set_title("Animated EEG Trace - TP9")
        ax.set_xlabel("Time (seconds)")

        def init():
            line.set_data([], [])
            return line,

        def animate(i):
            start = i * step
            end = start + window
            y = tp9_filtered[start:end]
            # Derive x from the slice actually returned so x and y always have
            # equal length, even if the slice is truncated at the end of the data.
            x = np.arange(start, start + y.size) / sampling_rate
            line.set_data(x, y)
            # Scroll the visible range with the window; a fixed 0-5 s range
            # would let the trace drift out of view after the first frame.
            ax.set_xlim(start / sampling_rate, end / sampling_rate)
            return line,

        # blit=False: the axis limits change every frame, so the whole figure
        # must be redrawn.
        ani = FuncAnimation(fig, animate, init_func=init, frames=n_frames, interval=100, blit=False)

        # Write the GIF into a directory that is removed afterwards (the old
        # NamedTemporaryFile(delete=False) left one file behind per click).
        # st.image is called while the file still exists.
        # NOTE(review): `use_column_width` is deprecated in newer Streamlit
        # releases; kept for compatibility with the installed version -- confirm.
        with tempfile.TemporaryDirectory() as tmpdir:
            gif_path = os.path.join(tmpdir, "eeg_trace.gif")
            ani.save(gif_path, writer='pillow', fps=10)
            st.image(gif_path, caption="Animated EEG Trace", use_column_width=True)
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
    finally:
        # Close the figure even on failure to avoid leaking it in the
        # long-lived Streamlit process.
        if fig is not None:
            plt.close(fig)