Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import numpy as np | |
# Streamlit settings applied at import time, before any st.* element is rendered.
# NOTE(review): turning off XSRF protection weakens CSRF defences for the
# upload endpoint; presumably this is needed so file uploads work when the app
# is embedded in a hosting iframe — confirm it is still required.
_XSRF_OPTION = "server.enableXsrfProtection"
st.config.set_option(_XSRF_OPTION, False)
st.set_page_config(
    page_title="Dysphagia Analysis",
    page_icon="👅",
)
| # Function to plot the EMG signal Coordination Analysis | |
| def emg_plot(event_index, event_plot_name, left_std_ratio, left_delta_t, right_std_ratio, right_delta_t): | |
| """ | |
| Plots a 2D quadrant chart for EMG signal analysis with colored squares indicating the quadrant. | |
| Parameters: | |
| std (float): Standard deviation value of the EMG signal. | |
| delta_t (float): Delta T value of the EMG signal. | |
| """ | |
| # Create a new figure | |
| fig, ax = plt.subplots(figsize=(8, 8)) | |
| # Determine the quadrant and plot the colored square | |
| if left_std_ratio > 3 and left_delta_t > 0: | |
| # First quadrant | |
| ax.add_patch(plt.Rectangle((2, 2), 6, 6, color='blue', alpha=0.5)) | |
| elif left_std_ratio <= 3 and left_delta_t > 0: | |
| # Second quadrant | |
| ax.add_patch(plt.Rectangle((-8, 2), 6, 6, color='blue', alpha=0.5)) | |
| elif left_std_ratio <= 3 and left_delta_t <= 0: | |
| # Third quadrant | |
| ax.add_patch(plt.Rectangle((-8, -8), 6, 6, color='blue', alpha=0.5)) | |
| elif left_std_ratio > 3 and left_delta_t <= 0: | |
| # Fourth quadrant | |
| ax.add_patch(plt.Rectangle((2, -8), 6, 6, color='blue', alpha=0.5)) | |
| # Determine the quadrant and plot the colored square | |
| if right_std_ratio > 3 and right_delta_t > 0: | |
| # First quadrant | |
| ax.add_patch(plt.Rectangle((1.5, 1.5), 6, 6, color='green', alpha=0.5)) | |
| elif right_std_ratio <= 3 and right_delta_t > 0: | |
| # Second quadrant | |
| ax.add_patch(plt.Rectangle((-8.5, 1.5), 6, 6, color='green', alpha=0.5)) | |
| elif right_std_ratio <= 3 and right_delta_t <= 0: | |
| # Third quadrant | |
| ax.add_patch(plt.Rectangle((-8.5, -8.5), 6, 6, color='green', alpha=0.5)) | |
| elif right_std_ratio > 3 and right_delta_t <= 0: | |
| # Fourth quadrant | |
| ax.add_patch(plt.Rectangle((1.5, -8.5), 6, 6, color='green', alpha=0.5)) | |
| # Add horizontal and vertical lines to create quadrants | |
| plt.axhline(y=0, color='black', linestyle='--') | |
| plt.axvline(x=0, color='black', linestyle='--') | |
| # Add quadrant labels | |
| # Add styled text labels with colored box | |
| def add_styled_text(x, y, text, va='bottom'): | |
| # Create text box style | |
| bbox_props = dict( | |
| boxstyle='round,pad=0.5', | |
| fc='#1f77b4', # 背景顏色(藍色) | |
| ec='#1f77b4', # 邊框顏色(藍色) | |
| alpha=0.7, # 背景透明度 | |
| lw=1.5 # 邊框寬度 | |
| ) | |
| plt.text(x, y, text, | |
| horizontalalignment='center', | |
| verticalalignment=va, | |
| bbox=bbox_props, | |
| color='white', | |
| fontweight='semibold', | |
| fontsize=9) | |
| def add_circle_text(x, y, text, va='bottom'): | |
| # Create text box style | |
| bbox_props = dict( | |
| boxstyle='circle,pad=0.5', | |
| fc='#262626', # 背景顏色(黑色) | |
| ec='#262626', # 邊框顏色(黑色) | |
| alpha=0.7, # 背景透明度 | |
| lw=1.5 # 邊框寬度 | |
| ) | |
| plt.text(x, y, text, | |
| horizontalalignment='center', | |
| verticalalignment=va, | |
| bbox=bbox_props, | |
| color='white', | |
| fontweight='semibold', | |
| fontsize=10) | |
| # Add styled quadrant labels | |
| add_styled_text(5, 0.5, "Exertion + / Coordination -", 'bottom') | |
| add_circle_text(1, 0.5, "4", 'bottom') | |
| add_styled_text(-4, 0.5, "Exertion - / Coordination -", 'bottom') | |
| add_circle_text(-8, 0.5, "2", 'bottom') | |
| add_styled_text(-4, -0.5, "Exertion - / Coordination +", 'top') | |
| add_circle_text(-8, -0.5, "1", 'top') | |
| add_styled_text(5, -0.5, "Exertion + / Coordination +", 'top') | |
| add_circle_text(1, -0.5, "3", 'top') | |
| # Add title and axis labels | |
| plt.title(f'Muscle Coordination Analysis - {event_index}:{event_plot_name}', fontsize=14) | |
| plt.xlabel('Exertion (Std Ratio > 3)', fontsize=12, fontweight='semibold') | |
| plt.ylabel('Coordination (Delta T > 0)', fontsize=12, fontweight='semibold') | |
| # Remove axis numbers and labels | |
| ax.set_xticks([]) | |
| ax.set_yticks([]) | |
| ax.set_xticklabels([]) | |
| ax.set_yticklabels([]) | |
| # Set plot legend with color | |
| plt.legend(['Left Swallowing Muscle', 'Right Swallowing Muscle'], loc='upper left', fontsize=10) | |
| # Set the limits of the plot | |
| plt.xlim(-10, 10) | |
| plt.ylim(-10, 10) | |
| # Display the plot | |
| st.pyplot(plt.gcf()) | |
| #plt.show() | |
| def main(): | |
| st.image("logo/itri_logo.jpg", width=700) | |
| st.title('👅Dysphagia Analysis - by ITRI BDL') | |
| # Initialize session state variables | |
| if 'emg_data' not in st.session_state: | |
| st.session_state.emg_data = None | |
| if 'time_marker' not in st.session_state: | |
| st.session_state.time_marker = None | |
| if 'analysis_started' not in st.session_state: | |
| st.session_state.analysis_started = False | |
| if 'data_isready' not in st.session_state: | |
| st.session_state.data_isready = False | |
| # File Uploaders | |
| uploaded_file1 = st.file_uploader("Choose the EMG_data CSV file", type="csv") | |
| uploaded_file2 = st.file_uploader("Choose the time_marker CSV file", type="csv") | |
| # Load data when files are uploaded | |
| if uploaded_file1 is not None and uploaded_file2 is not None: | |
| try: | |
| st.session_state.emg_data = pd.read_csv(uploaded_file1, skiprows=[0,1,3,4]) | |
| st.session_state.time_marker = pd.read_csv(uploaded_file2) | |
| st.success("Data loaded successfully!") | |
| st.session_state.data_isready = True | |
| except Exception as e: | |
| st.error(f"Error: {e}") | |
| # Load test data button | |
| if st.button('Load Test Data', type="primary"): | |
| st.session_state.emg_data = pd.read_csv('test-new/0-New_Task-recording-0.csv', skiprows=[0,1,3,4]) | |
| st.session_state.time_marker = pd.read_csv('test-new/time_marker.csv') | |
| st.success("Test data loaded successfully!") | |
| st.session_state.data_isready = True | |
| # Display loaded data | |
| if st.session_state.emg_data is not None: | |
| st.subheader("EMG Data") | |
| st.dataframe(st.session_state.emg_data) | |
| if st.session_state.time_marker is not None: | |
| st.subheader("Time Marker") | |
| st.dataframe(st.session_state.time_marker) | |
| # Analysis button | |
| if st.session_state.data_isready: | |
| st.subheader("Muscle Coordination Analysis") | |
| if st.button('Start Analysis', type="primary"): | |
| st.session_state.analysis_started = True | |
| # Perform analysis if started | |
| if st.session_state.analysis_started: | |
| st.write('Analysis in progress...') | |
| # Reset emg data index with Channels | |
| emg_data = st.session_state.emg_data.set_index('Channels') | |
| # Get signal data from difference of emg_data | |
| signal_left_lateral = emg_data['17'] - emg_data['18'] | |
| signal_left_medial = emg_data['19'] - emg_data['20'] | |
| signal_right_lateral = emg_data['23'] - emg_data['24'] | |
| signal_right_medial = emg_data['21'] - emg_data['22'] | |
| # RMS caculation : Define the moving average window size | |
| N = 25 | |
| # Function to calculate moving RMS | |
| def moving_rms(signal, window_size): | |
| rms = np.sqrt(pd.Series(signal).rolling(window=window_size).mean()**2) | |
| rms.index = signal.index # Ensure the index matches the original signal | |
| return rms | |
| # Calculate moving RMS for each channel | |
| signal_left_lateral_RMS = moving_rms(signal_left_lateral, N) | |
| signal_left_medial_RMS = moving_rms(signal_left_medial, N) | |
| signal_right_lateral_RMS = moving_rms(signal_right_lateral, N) | |
| signal_right_medial_RMS = moving_rms(signal_right_medial, N) | |
| # Time Marker Processing | |
| time_marker = st.session_state.time_marker[['0-New_Task-recording_time(us)', 'description', 'tag']] | |
| time_marker = time_marker.rename(columns={'0-New_Task-recording_time(us)': 'event_time'}) | |
| # Select column value with odd/even index | |
| event_start_times = time_marker.loc[0::2]['event_time'].values.astype(int) | |
| event_end_times = time_marker.loc[1::2]['event_time'].values.astype(int) | |
| event_names = time_marker.loc[0::2]['description'].values | |
| # Get signal basic 10s std | |
| signal_left_lateral_basics_10s_std = signal_left_lateral_RMS.loc[: 10000000].std() | |
| signal_right_lateral_basics_10s_std = signal_right_lateral_RMS.loc[: 10000000].std() | |
| # Analyze event data | |
| event_number = len(event_names) | |
| for i in range(1, 2*event_number, 2): | |
| event_name = event_names[i//2] | |
| event_start_time = event_start_times[i//2] | |
| event_end_time = event_end_times[i//2] | |
| st.write(f"Event {i//2+1}: {event_name}") | |
| st.write(f"Start time: {float(event_start_time)/1000000: .3f} sec, End time: {float(event_end_time)/1000000: .3f} sec") | |
| # Get event signal data with event time duration | |
| mask_LL = (signal_left_lateral_RMS.index >= event_start_time) & (signal_left_lateral_RMS.index <= event_end_time) | |
| event_signal_LL = signal_left_lateral_RMS.iloc[mask_LL] | |
| mask_LM = (signal_left_medial_RMS.index >= event_start_time) & (signal_left_medial_RMS.index <= event_end_time) | |
| event_signal_LM = signal_left_medial_RMS.iloc[mask_LM] | |
| mask_RL = (signal_right_lateral_RMS.index >= event_start_time) & (signal_right_lateral_RMS.index <= event_end_time) | |
| event_signal_RL = signal_right_lateral_RMS.iloc[mask_RL] | |
| mask_RM = (signal_right_medial_RMS.index >= event_start_time) & (signal_right_medial_RMS.index <= event_end_time) | |
| event_signal_RM = signal_right_medial_RMS.iloc[mask_RM] | |
| # Calculate std ratio | |
| left_event_std = event_signal_LL.std() | |
| left_std_ratio = left_event_std / signal_left_lateral_basics_10s_std | |
| right_event_std = event_signal_RL.std() | |
| right_std_ratio = right_event_std / signal_right_lateral_basics_10s_std | |
| st.write(f"left std ratio: {left_std_ratio: .3f}, right std ratio: {right_std_ratio: .3f}") | |
| # Get signal max value index | |
| LL_max_index = event_signal_LL.idxmax() | |
| LM_max_index = event_signal_LM.idxmax() | |
| left_delta_t = LM_max_index - LL_max_index | |
| st.write(f"LM_max_index: {float(LM_max_index)/1000000: .3f}, LL_max_index: {float(LL_max_index)/1000000: .3f}, left delta t: {float(left_delta_t)/1000000: .3f}") | |
| RL_max_index = event_signal_RL.idxmax() | |
| RM_max_index = event_signal_RM.idxmax() | |
| right_delta_t = RM_max_index - RL_max_index | |
| st.write(f"RM_max_index: {float(RM_max_index)/1000000: .3f}, RL_max_index: {float(RL_max_index)/1000000: .3f}, right delta t: {float(right_delta_t)/1000000: .3f}") | |
| # Plot with each event data | |
| emg_plot(i//2+1, event_name, left_std_ratio, left_delta_t, right_std_ratio, right_delta_t) | |
| st.write('Analysis completed!') | |
| if __name__ == '__main__': | |
| main() |