Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| from pathlib import Path | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
# --- Config ---
CACHE_DIR = Path("trial_cache")  # directory searched for per-trial .npz files
METADATA_FILE = "subject_trials.csv"
# Every column this module reads from the metadata table.
_REQUIRED_COLUMNS = ("subject", "trial", "target_time")

# Load metadata (source of truth for the event time drawn on the plots).
try:
    metadata_df = pd.read_csv(METADATA_FILE)
except FileNotFoundError as err:
    # Re-raise with a friendlier message while keeping the original cause.
    raise FileNotFoundError(f"Error: Could not find '{METADATA_FILE}'.") from err

# Fail fast with one clear message if any column used later is absent,
# instead of a bare KeyError from the first lookup that happens to need it.
_missing_columns = [
    col for col in _REQUIRED_COLUMNS if col not in metadata_df.columns
]
if _missing_columns:
    _missing_text = ", ".join(f"'{col}'" for col in _missing_columns)
    raise ValueError(
        f"Column(s) {_missing_text} not found in {METADATA_FILE}. "
        "Please check your CSV."
    )

# Distinct subject identifiers, sorted for a stable dropdown order.
subjects = sorted(metadata_df["subject"].unique())
print(f"✅ Loaded metadata from {METADATA_FILE} with {len(subjects)} subjects.")
def get_trials_for_subject(sbj: str):
    """Return the sorted 'trial' values of all metadata rows for one subject.

    Args:
        sbj: Value compared for equality against the 'subject' column.

    Returns:
        A new list of the matching 'trial' values in ascending order; empty
        when no row matches.
    """
    subject_rows = metadata_df.loc[metadata_df["subject"] == sbj]
    trial_ids = subject_rows["trial"].tolist()
    trial_ids.sort()
    return trial_ids
def get_target_time(sbj: str, trial: int):
    """Look up the 'target_time' recorded in the metadata CSV for one trial.

    Args:
        sbj: Value matched against the 'subject' column.
        trial: Value matched against the 'trial' column.

    Returns:
        The 'target_time' of the first matching row, converted to float.

    Raises:
        ValueError: If no row matches both the subject and the trial.
    """
    is_match = (metadata_df["subject"] == sbj) & (metadata_df["trial"] == trial)
    matches = metadata_df[is_match]
    if not matches.empty:
        # Only the first matching row is used if several rows match.
        first_match = matches.iloc[0]
        return float(first_match["target_time"])
    raise ValueError(f"No entry found for {sbj} Trial {trial} in CSV.")
def load_trial_data(sbj: str, trial: int):
    """Load one cached trial archive and attach its CSV-defined event time.

    The archive is read from ``CACHE_DIR / "{sbj}_trial{trial:02d}.npz"``.

    Args:
        sbj: Subject identifier used in the file name and the CSV lookup.
        trial: Trial number; formatted as a zero-padded two-digit integer in
            the file name.

    Returns:
        A dict with keys 'times', 'rp_pi', 'phase_l_pi', 'phase_r_pi'
        (arrays as stored in the archive), 'candidates' (the unwrapped
        Python object stored in the archive), 'sbj' (str), 'trial_number'
        (int) and 'target_time' (float taken from the metadata CSV, not
        from the archive).

    Raises:
        FileNotFoundError: If the archive for this subject/trial is missing.
        ValueError: If the CSV has no row for this subject/trial
            (propagated from ``get_target_time``).
    """
    path = CACHE_DIR / f"{sbj}_trial{trial:02d}.npz"
    if not path.exists():
        raise FileNotFoundError(f"Trial file not found: {path}")

    # NOTE(security): allow_pickle=True is needed to unwrap the object stored
    # under 'candidates', but unpickling can execute arbitrary code. Only
    # load archives from a trusted cache directory.
    #
    # The context manager closes the underlying file once the arrays have
    # been read; previously the archive handle stayed open after every call.
    with np.load(path, allow_pickle=True) as data:
        # Get time exclusively from CSV
        target_time = get_target_time(sbj, trial)
        return {
            'times': data['times'],
            'rp_pi': data['rp_pi'],
            'phase_l_pi': data['phase_l_pi'],
            'phase_r_pi': data['phase_r_pi'],
            'candidates': data['candidates'].item(),
            'sbj': str(data['sbj']),
            'trial_number': int(data['trial_number']),
            'target_time': target_time,  # Only this matters now
        }
def plot_full_trial_plotly(results: dict):
    """Build the two-row Plotly figure for a single trial.

    Row 1 shows the left and right phase traces, row 2 the relative phase.
    Both rows share the x axis and get a green vertical line at the trial's
    'target_time'. A range slider is attached to the bottom x axis.

    Args:
        results: Dict providing 'times', 'phase_l_pi', 'phase_r_pi', 'rp_pi',
            'target_time', 'sbj' and 'trial_number' (the shape returned by
            ``load_trial_data``).

    Returns:
        The assembled Plotly figure.
    """
    times = results['times']
    phase_l = results['phase_l_pi']
    phase_r = results['phase_r_pi']
    rp = results['rp_pi']
    event_time = results['target_time']

    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        subplot_titles=(
            f"{results['sbj']} — Trial {results['trial_number']}: Phases",
            "Relative Phase (R − L)",
        ),
    )

    fig.add_trace(
        go.Scatter(x=times, y=phase_l, mode='lines', name='Left Phase',
                   line=dict(color='blue')),
        row=1, col=1,
    )
    fig.add_trace(
        go.Scatter(x=times, y=phase_r, mode='lines', name='Right Phase',
                   line=dict(color='red')),
        row=1, col=1,
    )
    fig.add_trace(
        go.Scatter(x=times, y=rp, mode='lines', name='Relative Phase',
                   line=dict(color='teal')),
        row=2, col=1,
    )

    # Green line: the corrected time from the CSV. Only the top row carries
    # the text annotation; the bottom row repeats the line without a label.
    fig.add_vline(x=event_time, line=dict(color='green', width=4),
                  annotation_text="Corrected", annotation_position="top left",
                  row=1, col=1)
    fig.add_vline(x=event_time, line=dict(color='green', width=4), row=2, col=1)

    fig.update_layout(
        height=700,
        hovermode='x unified',
        title_x=0.5,
        # The slider is attached to xaxis2, i.e. the bottom subplot's axis.
        xaxis2_rangeslider_visible=True,
        xaxis2_rangeslider_thickness=0.08,
    )
    fig.update_yaxes(title_text="Phase (×π rad)", row=1, col=1)
    fig.update_yaxes(title_text="RP (×π rad)", row=2, col=1)
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    return fig
# --- Gradio UI ---
with gr.Blocks(title="Phase Inspector") as demo:
    gr.Markdown("## Phase Transition Inspector")
    gr.Markdown("🟢 **Green Line**: Time defined in `subject_trials.csv`")
    plot_output = gr.Plot()

    with gr.Row():
        # Pre-select the first subject and its first trial, when available,
        # so the app opens with a populated selection.
        initial_sbj = subjects[0] if subjects else None
        initial_trials = get_trials_for_subject(initial_sbj) if initial_sbj else []
        initial_trial = initial_trials[0] if initial_trials else None
        sbj_dropdown = gr.Dropdown(choices=subjects, label="Subject", value=initial_sbj)
        trial_dropdown = gr.Dropdown(choices=initial_trials, label="Trial", value=initial_trial)
        load_btn = gr.Button("🔄 Load Trial")

    def on_subject_change(sbj: str):
        """Return a Dropdown update listing the trials of the chosen subject.

        The first trial is pre-selected; with no subject (or no trials) the
        dropdown is emptied.
        """
        if not sbj:
            return gr.Dropdown(choices=[], value=None)
        trials = get_trials_for_subject(sbj)
        first_trial = trials[0] if trials else None
        return gr.Dropdown(choices=trials, value=first_trial)

    def on_load_trial(sbj: str, trial: int):
        """Return the figure for the selected trial.

        On a missing selection or any load/plot failure, an empty figure
        whose title carries the message is returned instead of raising.
        """
        # `trial is None` rather than `not trial`: trial number 0 is valid.
        if not sbj or trial is None:
            empty_fig = go.Figure()
            empty_fig.update_layout(title="Please select a subject and trial.")
            return empty_fig
        try:
            results = load_trial_data(sbj, trial)
            return plot_full_trial_plotly(results)
        except Exception as e:
            # UI boundary: surface the failure in the plot area rather than
            # letting the callback raise.
            empty_fig = go.Figure()
            empty_fig.update_layout(title=f"Error: {e}")
            return empty_fig

    sbj_dropdown.change(
        fn=on_subject_change,
        inputs=sbj_dropdown,
        outputs=trial_dropdown,
    )
    load_btn.click(
        fn=on_load_trial,
        inputs=[sbj_dropdown, trial_dropdown],
        outputs=plot_output,
    )

    # Auto-load on start. The trial is compared against None explicitly so
    # that a first trial numbered 0 still triggers the initial plot, matching
    # the check inside on_load_trial.
    if initial_sbj and initial_trial is not None:
        demo.load(
            fn=on_load_trial,
            inputs=[sbj_dropdown, trial_dropdown],
            outputs=plot_output,
        )
# Script entry point: launch the app with the Soft theme only when this file
# is executed directly, not when it is imported.
if __name__ == "__main__":
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)