# Author: JayLacoma
# Commit b68834c: "Update app.py" (verified)
# app.py
import gradio as gr
import numpy as np
import pandas as pd
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# --- Config ---
CACHE_DIR = Path("trial_cache")
METADATA_FILE = "subject_trials.csv"
# Columns the app reads from the metadata CSV. 'target_time' is the source of
# truth for the event marker drawn on every plot.
REQUIRED_COLUMNS = ("subject", "trial", "target_time")

# Load metadata (Source of Truth for time)
try:
    metadata_df = pd.read_csv(METADATA_FILE)
except FileNotFoundError as err:
    raise FileNotFoundError(f"Error: Could not find '{METADATA_FILE}'.") from err

# Fail fast at startup if the CSV lacks any column the callbacks rely on,
# instead of hitting a bare KeyError later inside a UI callback.
_missing_columns = [c for c in REQUIRED_COLUMNS if c not in metadata_df.columns]
if _missing_columns:
    _listed = ", ".join(f"'{c}'" for c in _missing_columns)
    raise ValueError(f"Column {_listed} not found in {METADATA_FILE}. Please check your CSV.")

subjects = sorted(metadata_df["subject"].unique())
print(f"✅ Loaded metadata from {METADATA_FILE} with {len(subjects)} subjects.")
def get_trials_for_subject(sbj: str):
    """Return the trial numbers listed for subject ``sbj`` in the metadata CSV, sorted ascending."""
    trial_column = metadata_df.loc[metadata_df["subject"] == sbj, "trial"]
    trial_list = trial_column.tolist()
    trial_list.sort()
    return trial_list
def get_target_time(sbj: str, trial: int):
    """Fetches the corrected time directly from the CSV.

    Looks up the row for (``sbj``, ``trial``) in the metadata table and returns
    its ``target_time`` as a float. If several rows match, the first one wins.

    Raises:
        ValueError: if no row matches the subject/trial pair.
    """
    is_match = (metadata_df["subject"] == sbj) & (metadata_df["trial"] == trial)
    matches = metadata_df.loc[is_match]
    if matches.empty:
        raise ValueError(f"No entry found for {sbj} Trial {trial} in CSV.")
    first_match = matches.iloc[0]
    return float(first_match["target_time"])
def load_trial_data(sbj: str, trial: int):
    """Load one cached trial and attach its corrected event time.

    Reads ``CACHE_DIR / f"{sbj}_trial{trial:02d}.npz"`` and combines its arrays
    with ``target_time`` looked up in the metadata CSV.

    Args:
        sbj: Subject identifier, used in the cache filename and CSV lookup.
        trial: Trial number; must support ``:02d`` formatting (i.e. an int).

    Returns:
        dict with keys 'times', 'rp_pi', 'phase_l_pi', 'phase_r_pi',
        'candidates', 'sbj', 'trial_number' and 'target_time'.

    Raises:
        FileNotFoundError: if the cache file does not exist.
        ValueError: if the CSV has no row for (sbj, trial).
    """
    path = CACHE_DIR / f"{sbj}_trial{trial:02d}.npz"
    if not path.exists():
        raise FileNotFoundError(f"Trial file not found: {path}")

    # allow_pickle is needed because 'candidates' is stored as a pickled object
    # (unwrapped with .item() below). Unpickling can execute arbitrary code, so
    # only load cache files you trust.
    # The `with` block closes the file handle held by the NpzFile, including
    # when get_target_time raises. Each data[...] access is materialized before
    # the block exits, so the returned values stay valid after closing.
    with np.load(path, allow_pickle=True) as data:
        # Get time exclusively from CSV
        target_time = get_target_time(sbj, trial)
        return {
            'times': data['times'],
            'rp_pi': data['rp_pi'],
            'phase_l_pi': data['phase_l_pi'],
            'phase_r_pi': data['phase_r_pi'],
            'candidates': data['candidates'].item(),
            'sbj': str(data['sbj']),
            'trial_number': int(data['trial_number']),
            'target_time': target_time,  # Only this matters now
        }
def plot_full_trial_plotly(results: dict):
    """Build the two-row phase figure for a single trial.

    Row 1 overlays the left (blue) and right (red) phase traces. Row 2 shows
    the relative phase (teal). The rows share the x axis. A green vertical line
    marks ``results['target_time']`` on both rows; only the top row carries the
    "Corrected" label. A range slider is attached to the bottom x axis.

    Args:
        results: Trial dictionary as returned by ``load_trial_data``. The keys
            read here are 'times', 'phase_l_pi', 'phase_r_pi', 'rp_pi',
            'target_time', 'sbj' and 'trial_number'.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    times = results['times']
    phase_l = results['phase_l_pi']
    phase_r = results['phase_r_pi']
    rp = results['rp_pi']
    event_time = results['target_time']

    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        subplot_titles=(
            f"{results['sbj']} – Trial {results['trial_number']}: Phases",
            "Relative Phase (R − L)",
        ),
    )

    fig.add_trace(go.Scatter(x=times, y=phase_l, mode='lines', name='Left Phase', line=dict(color='blue')), row=1, col=1)
    fig.add_trace(go.Scatter(x=times, y=phase_r, mode='lines', name='Right Phase', line=dict(color='red')), row=1, col=1)
    fig.add_trace(go.Scatter(x=times, y=rp, mode='lines', name='Relative Phase', line=dict(color='teal')), row=2, col=1)

    # 🟢 Green Line: The Corrected Time from CSV (labelled on the top row only)
    fig.add_vline(x=event_time, line=dict(color='green', width=4),
                  annotation_text="Corrected", annotation_position="top left",
                  row=1, col=1)
    fig.add_vline(x=event_time, line=dict(color='green', width=4), row=2, col=1)

    fig.update_layout(
        height=700,
        hovermode='x unified',
        title_x=0.5,
        xaxis2_rangeslider_visible=True,
        xaxis2_rangeslider_thickness=0.08,
    )
    # Phase values are plotted in units of π radians.
    fig.update_yaxes(title_text="Phase (×π rad)", row=1, col=1)
    fig.update_yaxes(title_text="RP (×π rad)", row=2, col=1)
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    return fig
# --- Gradio UI ---
with gr.Blocks(title="Phase Inspector") as demo:
    gr.Markdown("## Phase Transition Inspector")
    gr.Markdown("🟢 **Green Line**: Time defined in `subject_trials.csv`")
    plot_output = gr.Plot()

    with gr.Row():
        # Pre-select the first subject and its first trial so the page opens on data.
        initial_sbj = subjects[0] if subjects else None
        initial_trials = get_trials_for_subject(initial_sbj) if initial_sbj else []
        initial_trial = initial_trials[0] if initial_trials else None
        sbj_dropdown = gr.Dropdown(choices=subjects, label="Subject", value=initial_sbj)
        trial_dropdown = gr.Dropdown(choices=initial_trials, label="Trial", value=initial_trial)
        load_btn = gr.Button("🔍 Load Trial")

    def _message_figure(message: str):
        """Return an empty Plotly figure whose title shows ``message`` (prompts and errors)."""
        fig = go.Figure()
        fig.update_layout(title=message)
        return fig

    def on_subject_change(sbj: str):
        """Repopulate the trial dropdown for ``sbj`` and select its first trial (if any)."""
        if not sbj:
            return gr.Dropdown(choices=[], value=None)
        trials = get_trials_for_subject(sbj)
        first_trial = trials[0] if trials else None
        return gr.Dropdown(choices=trials, value=first_trial)

    def on_load_trial(sbj: str, trial: int):
        """Load the selected trial and return its figure.

        A missing selection, or any error raised while loading or plotting, is
        reported in the title of an empty figure instead of propagating to the UI.
        """
        if not sbj or trial is None:
            return _message_figure("Please select a subject and trial.")
        try:
            results = load_trial_data(sbj, trial)
            return plot_full_trial_plotly(results)
        except Exception as e:  # UI boundary: show the failure instead of crashing the callback
            import traceback
            traceback.print_exc()  # keep the full traceback in the server log
            return _message_figure(f"Error: {e}")

    sbj_dropdown.change(
        fn=on_subject_change,
        inputs=sbj_dropdown,
        outputs=trial_dropdown,
    )
    load_btn.click(
        fn=on_load_trial,
        inputs=[sbj_dropdown, trial_dropdown],
        outputs=plot_output,
    )

    # Auto-load on start. Compare against None so a trial numbered 0 still
    # triggers the initial plot (a plain truthiness test would skip it).
    if initial_sbj and initial_trial is not None:
        demo.load(
            fn=on_load_trial,
            inputs=[sbj_dropdown, trial_dropdown],
            outputs=plot_output,
        )
if __name__ == "__main__":
    # Script entry point: serve the app using Gradio's Soft theme.
    # NOTE(review): launch(theme=...) requires a Gradio version whose launch()
    # accepts `theme` (6.x); older versions take it in gr.Blocks(...) instead.
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)