# app.py
"""Gradio app for inspecting left/right phase traces of a single trial.

Trial arrays are read from ``.npz`` files in ``CACHE_DIR``; the event time
drawn on the plots is taken exclusively from the ``target_time`` column of
``METADATA_FILE``.
"""
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# --- Config ---
CACHE_DIR = Path("trial_cache")
METADATA_FILE = "subject_trials.csv"

# Columns the rest of this module reads from the metadata CSV.
_REQUIRED_COLUMNS = ("subject", "trial", "target_time")

# Load metadata (source of truth for the event time).
# Only the read itself can raise FileNotFoundError, so keep the try minimal.
try:
    metadata_df = pd.read_csv(METADATA_FILE)
except FileNotFoundError as err:
    raise FileNotFoundError(f"Error: Could not find '{METADATA_FILE}'.") from err

# Fail early with a readable message instead of a KeyError deep in a callback.
for _column in _REQUIRED_COLUMNS:
    if _column not in metadata_df.columns:
        raise ValueError(
            f"Column '{_column}' not found in {METADATA_FILE}. Please check your CSV."
        )

subjects = sorted(metadata_df["subject"].unique())
print(f"✅ Loaded metadata from {METADATA_FILE} with {len(subjects)} subjects.")


def get_trials_for_subject(sbj: str) -> list:
    """Return the sorted trial identifiers listed in the CSV for ``sbj``."""
    trials = metadata_df.loc[metadata_df["subject"] == sbj, "trial"].tolist()
    return sorted(trials)


def get_target_time(sbj: str, trial: int) -> float:
    """Fetch the corrected time for a subject/trial directly from the CSV.

    Raises:
        ValueError: if the CSV has no row for this subject/trial pair.
    """
    matches = metadata_df[
        (metadata_df["subject"] == sbj) & (metadata_df["trial"] == trial)
    ]
    if matches.empty:
        raise ValueError(f"No entry found for {sbj} Trial {trial} in CSV.")
    # If several rows match, the first one wins.
    return float(matches.iloc[0]["target_time"])


def load_trial_data(sbj: str, trial: int) -> dict:
    """Load one cached trial and attach its CSV-defined target time.

    Returns:
        A dict with keys ``times``, ``rp_pi``, ``phase_l_pi``, ``phase_r_pi``,
        ``candidates``, ``sbj``, ``trial_number`` and ``target_time``.

    Raises:
        FileNotFoundError: if the trial's ``.npz`` file is not in ``CACHE_DIR``.
        ValueError: if the CSV has no row for this subject/trial pair.
    """
    path = CACHE_DIR / f"{sbj}_trial{trial:02d}.npz"
    if not path.exists():
        raise FileNotFoundError(f"Trial file not found: {path}")

    # The time comes exclusively from the CSV, never from the cache file.
    target_time = get_target_time(sbj, trial)

    # NOTE(security): allow_pickle=True unpickles object arrays, which can
    # execute arbitrary code. Only load cache files from a trusted source.
    # It is presumably needed because 'candidates' is stored as a pickled
    # object (it is unwrapped with .item()) - confirm against the writer.
    #
    # NpzFile holds an open file handle and loads members lazily, so every
    # member must be read inside the with-block before the file is closed.
    with np.load(path, allow_pickle=True) as data:
        return {
            'times': data['times'],
            'rp_pi': data['rp_pi'],
            'phase_l_pi': data['phase_l_pi'],
            'phase_r_pi': data['phase_r_pi'],
            'candidates': data['candidates'].item(),
            'sbj': str(data['sbj']),
            'trial_number': int(data['trial_number']),
            'target_time': target_time,
        }


def plot_full_trial_plotly(results: dict):
    """Build the two-row figure for one trial.

    Row 1 shows the left and right phase traces, row 2 the relative phase.
    Both rows share the x axis and get a green vertical line at
    ``results['target_time']``.

    Args:
        results: a dict shaped like the return value of ``load_trial_data``.

    Returns:
        The assembled Plotly figure.
    """
    times = results['times']
    event_time = results['target_time']

    fig = make_subplots(
        rows=2,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        subplot_titles=(
            f"{results['sbj']} – Trial {results['trial_number']}: Phases",
            "Relative Phase (R − L)",
        ),
    )

    # (y-values, legend name, colour, subplot row)
    traces = (
        (results['phase_l_pi'], 'Left Phase', 'blue', 1),
        (results['phase_r_pi'], 'Right Phase', 'red', 1),
        (results['rp_pi'], 'Relative Phase', 'teal', 2),
    )
    for y_values, label, colour, row in traces:
        fig.add_trace(
            go.Scatter(x=times, y=y_values, mode='lines', name=label,
                       line=dict(color=colour)),
            row=row, col=1,
        )

    # 🟢 Green line: the corrected time from the CSV. Only the top row is
    # annotated so the label is not repeated.
    fig.add_vline(x=event_time, line=dict(color='green', width=4),
                  annotation_text="Corrected", annotation_position="top left",
                  row=1, col=1)
    fig.add_vline(x=event_time, line=dict(color='green', width=4),
                  row=2, col=1)

    fig.update_layout(
        height=700,
        hovermode='x unified',
        title_x=0.5,
        xaxis2_rangeslider_visible=True,
        xaxis2_rangeslider_thickness=0.08,
    )
    fig.update_yaxes(title_text="Phase (×π rad)", row=1, col=1)
    fig.update_yaxes(title_text="RP (×π rad)", row=2, col=1)
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    return fig


def _message_figure(message: str):
    """Return an empty figure whose title carries ``message``."""
    fig = go.Figure()
    fig.update_layout(title=message)
    return fig


# --- Gradio UI ---
with gr.Blocks(title="Phase Inspector") as demo:
    gr.Markdown("## Phase Transition Inspector")
    gr.Markdown("🟢 **Green Line**: Time defined in `subject_trials.csv`")

    plot_output = gr.Plot()

    with gr.Row():
        # Pre-select the first subject and its first trial, when there are any.
        initial_sbj = subjects[0] if subjects else None
        initial_trials = get_trials_for_subject(initial_sbj) if initial_sbj else []
        initial_trial = initial_trials[0] if initial_trials else None

        sbj_dropdown = gr.Dropdown(choices=subjects, label="Subject", value=initial_sbj)
        trial_dropdown = gr.Dropdown(choices=initial_trials, label="Trial", value=initial_trial)
        load_btn = gr.Button("🔍 Load Trial")

    def on_subject_change(sbj: str):
        """Repopulate the trial dropdown for the newly selected subject."""
        if not sbj:
            return gr.Dropdown(choices=[], value=None)
        trials = get_trials_for_subject(sbj)
        first_trial = trials[0] if trials else None
        return gr.Dropdown(choices=trials, value=first_trial)

    def on_load_trial(sbj: str, trial: int):
        """Load and plot the selected trial; on failure return a titled empty figure."""
        # `trial is None` rather than `not trial`: trial 0 is a valid selection.
        if not sbj or trial is None:
            return _message_figure("Please select a subject and trial.")
        try:
            return plot_full_trial_plotly(load_trial_data(sbj, trial))
        except Exception as e:
            # UI boundary: show any failure in the plot area instead of
            # letting the callback raise.
            return _message_figure(f"Error: {e}")

    sbj_dropdown.change(
        fn=on_subject_change,
        inputs=sbj_dropdown,
        outputs=trial_dropdown,
    )

    load_btn.click(
        fn=on_load_trial,
        inputs=[sbj_dropdown, trial_dropdown],
        outputs=plot_output,
    )

    # Auto-load on start. Use the same None check as on_load_trial so that a
    # first trial numbered 0 is still auto-loaded.
    if initial_sbj and initial_trial is not None:
        demo.load(
            fn=on_load_trial,
            inputs=[sbj_dropdown, trial_dropdown],
            outputs=plot_output,
        )

if __name__ == "__main__":
    # NOTE(review): passing `theme` to launch() depends on the installed
    # Gradio version (older releases take it on gr.Blocks(...)) - confirm.
    demo.launch(theme=gr.themes.Soft())