Spaces:
Sleeping
Sleeping
File size: 5,584 Bytes
70e2016 fe2a734 a6a293f fe2a734 70e2016 fe2a734 70e2016 fe2a734 70e2016 8d318db fe2a734 70e2016 02b8f1b fe2a734 70e2016 fe2a734 70e2016 f63c041 70e2016 8d318db fe2a734 70e2016 fe2a734 70e2016 fe2a734 02b8f1b 70e2016 047ee29 70e2016 f63c041 047ee29 fe2a734 f63c041 70e2016 fe2a734 70e2016 fe2a734 a11decd fe2a734 a11decd fe2a734 a11decd 70e2016 fe2a734 70e2016 f63c041 a11decd 70e2016 fe2a734 70e2016 b68834c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 | # app.py
import gradio as gr
import numpy as np
import pandas as pd
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots
# --- Config ---
CACHE_DIR = Path("trial_cache")
METADATA_FILE = "subject_trials.csv"

# Load metadata: the CSV is the single source of truth for each trial's target time.
try:
    metadata_df = pd.read_csv(METADATA_FILE)
except FileNotFoundError as err:
    # Re-raise with a friendlier message, keeping the original error as the cause.
    raise FileNotFoundError(f"Error: Could not find '{METADATA_FILE}'.") from err

# Fail fast at startup if a column the app relies on is missing.
# 'target_time' is checked first so its (pre-existing) error message wins.
for _required_col in ("target_time", "subject", "trial"):
    if _required_col not in metadata_df.columns:
        raise ValueError(f"Column '{_required_col}' not found in {METADATA_FILE}. Please check your CSV.")

# Blank subject cells are read as NaN; drop them so sorting does not mix str and float.
subjects = sorted(metadata_df["subject"].dropna().unique())
print(f"✅ Loaded metadata from {METADATA_FILE} with {len(subjects)} subjects.")
def get_trials_for_subject(sbj: str):
    """Return the trial numbers listed for subject *sbj*, in ascending order.

    An unknown subject yields an empty list.
    """
    belongs_to_subject = metadata_df["subject"] == sbj
    subject_trials = metadata_df.loc[belongs_to_subject, "trial"].tolist()
    return sorted(subject_trials)
def get_target_time(sbj: str, trial: int) -> float:
    """Fetch the corrected target time for one subject/trial directly from the CSV.

    If several rows match, the first one is used.

    Raises:
        ValueError: if the CSV has no row for (sbj, trial), or that row's
            ``target_time`` cell is empty.
    """
    row = metadata_df[(metadata_df["subject"] == sbj) & (metadata_df["trial"] == trial)]
    if row.empty:
        raise ValueError(f"No entry found for {sbj} Trial {trial} in CSV.")
    value = row.iloc[0]["target_time"]
    # pandas reads an empty cell as NaN; float(NaN) would silently place the
    # event marker at NaN, so report the missing value instead.
    if pd.isna(value):
        raise ValueError(f"Missing target_time for {sbj} Trial {trial} in CSV.")
    return float(value)
def load_trial_data(sbj: str, trial: int):
    """Load one trial's cached phase data plus its CSV-defined target time.

    Reads ``CACHE_DIR / "<sbj>_trial<NN>.npz"`` (trial zero-padded to two
    digits). Returns a dict with keys 'times', 'rp_pi', 'phase_l_pi',
    'phase_r_pi', 'candidates', 'sbj', 'trial_number' and 'target_time'.
    'target_time' comes exclusively from the metadata CSV via get_target_time.

    Raises:
        FileNotFoundError: if the cached .npz file does not exist.
        ValueError: propagated from get_target_time when the CSV has no usable time.
    """
    path = CACHE_DIR / f"{sbj}_trial{trial:02d}.npz"
    if not path.exists():
        raise FileNotFoundError(f"Trial file not found: {path}")

    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # the context manager closes it once every member has been read into memory.
    # allow_pickle is needed for 'candidates' (presumably a 0-d object array,
    # hence .item()). NOTE(review): pickle-loading is only safe because the
    # cache is produced locally; never point this at untrusted files.
    with np.load(path, allow_pickle=True) as data:
        results = {
            'times': data['times'],
            'rp_pi': data['rp_pi'],
            'phase_l_pi': data['phase_l_pi'],
            'phase_r_pi': data['phase_r_pi'],
            'candidates': data['candidates'].item(),
            'sbj': str(data['sbj']),
            'trial_number': int(data['trial_number']),
        }

    # Time is taken from the CSV only; nothing time-related in the .npz is used.
    results['target_time'] = get_target_time(sbj, trial)
    return results
def plot_full_trial_plotly(results: dict):
    """Build a two-row Plotly figure for one trial.

    Row 1 plots the left (blue) and right (red) phase traces. Row 2 plots the
    relative phase (teal). Both rows share the time axis. A green vertical line
    marks ``results['target_time']`` in both rows; only the top one is
    annotated. A range slider is attached to the bottom x-axis.

    Args:
        results: dict as produced by load_trial_data. Reads 'times',
            'phase_l_pi', 'phase_r_pi', 'rp_pi', 'target_time', 'sbj' and
            'trial_number'.

    Returns:
        The assembled plotly ``go.Figure``.
    """
    times = results['times']
    phase_l = results['phase_l_pi']
    phase_r = results['phase_r_pi']
    rp = results['rp_pi']
    event_time = results['target_time']

    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        subplot_titles=(
            f"{results['sbj']} — Trial {results['trial_number']}: Phases",
            "Relative Phase (R − L)",
        ),
    )

    fig.add_trace(go.Scatter(x=times, y=phase_l, mode='lines', name='Left Phase', line=dict(color='blue')), row=1, col=1)
    fig.add_trace(go.Scatter(x=times, y=phase_r, mode='lines', name='Right Phase', line=dict(color='red')), row=1, col=1)
    fig.add_trace(go.Scatter(x=times, y=rp, mode='lines', name='Relative Phase', line=dict(color='teal')), row=2, col=1)

    # 🟢 Green line: the corrected time from the CSV, drawn in both rows.
    fig.add_vline(x=event_time, line=dict(color='green', width=4),
                  annotation_text="Corrected", annotation_position="top left",
                  row=1, col=1)
    fig.add_vline(x=event_time, line=dict(color='green', width=4), row=2, col=1)

    fig.update_layout(
        height=700,
        hovermode='x unified',
        title_x=0.5,
        xaxis2_rangeslider_visible=True,
        xaxis2_rangeslider_thickness=0.08,
    )
    # Data arrays are stored in units of π (the *_pi keys), hence "×π rad".
    fig.update_yaxes(title_text="Phase (×π rad)", row=1, col=1)
    fig.update_yaxes(title_text="RP (×π rad)", row=2, col=1)
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)
    return fig
# --- Gradio UI ---


def on_subject_change(sbj: str):
    """Repopulate the trial dropdown for the newly selected subject.

    Pre-selects the subject's first trial; clears the dropdown when no subject
    is selected or the subject has no trials.
    """
    if not sbj:
        return gr.Dropdown(choices=[], value=None)
    trials = get_trials_for_subject(sbj)
    return gr.Dropdown(choices=trials, value=trials[0] if trials else None)


def _message_figure(message: str):
    """Return an empty Plotly figure whose title displays *message*."""
    fig = go.Figure()
    fig.update_layout(title=message)
    return fig


def on_load_trial(sbj: str, trial: int):
    """Load and plot one trial, or return a placeholder figure explaining why not."""
    if not sbj or trial is None:
        return _message_figure("Please select a subject and trial.")
    # UI boundary: show any load/plot failure in the plot area instead of
    # letting the event handler crash.
    try:
        results = load_trial_data(sbj, trial)
        return plot_full_trial_plotly(results)
    except Exception as e:
        return _message_figure(f"Error: {e}")


with gr.Blocks(title="Phase Inspector") as demo:
    gr.Markdown("## Phase Transition Inspector")
    gr.Markdown(f"🟢 **Green Line**: Time defined in `{METADATA_FILE}`")
    plot_output = gr.Plot()

    # Pre-select the first subject and its first trial so the app opens populated.
    initial_sbj = subjects[0] if subjects else None
    initial_trials = get_trials_for_subject(initial_sbj) if initial_sbj else []
    initial_trial = initial_trials[0] if initial_trials else None

    with gr.Row():
        sbj_dropdown = gr.Dropdown(choices=subjects, label="Subject", value=initial_sbj)
        trial_dropdown = gr.Dropdown(choices=initial_trials, label="Trial", value=initial_trial)
        load_btn = gr.Button("🔄 Load Trial")

    sbj_dropdown.change(
        fn=on_subject_change,
        inputs=sbj_dropdown,
        outputs=trial_dropdown,
    )
    load_btn.click(
        fn=on_load_trial,
        inputs=[sbj_dropdown, trial_dropdown],
        outputs=plot_output,
    )

    # Auto-load on start. Compare the trial against None (as on_load_trial
    # does): a trial numbered 0 is falsy but still a valid selection.
    if initial_sbj and initial_trial is not None:
        demo.load(
            fn=on_load_trial,
            inputs=[sbj_dropdown, trial_dropdown],
            outputs=plot_output,
        )

if __name__ == "__main__":
    # NOTE(review): theme is passed to launch() here; older Gradio releases
    # expect it on gr.Blocks(...) instead — confirm against the installed version.
    demo.launch(theme=gr.themes.Soft())