File size: 5,584 Bytes
70e2016
 
 
 
 
 
 
 
 
 
fe2a734
a6a293f
fe2a734
 
 
 
 
 
 
 
 
 
 
 
70e2016
 
 
 
 
fe2a734
 
 
 
 
 
 
70e2016
 
fe2a734
 
 
 
70e2016
8d318db
fe2a734
 
 
70e2016
 
 
 
 
02b8f1b
fe2a734
70e2016
fe2a734
70e2016
 
f63c041
70e2016
 
 
 
8d318db
fe2a734
70e2016
 
 
 
 
 
fe2a734
70e2016
 
 
 
 
 
 
 
fe2a734
 
 
 
 
02b8f1b
70e2016
 
 
 
 
047ee29
70e2016
 
 
 
 
 
 
f63c041
047ee29
 
fe2a734
f63c041
70e2016
 
 
fe2a734
 
 
 
 
 
70e2016
 
fe2a734
 
 
 
 
 
 
a11decd
fe2a734
 
 
 
 
a11decd
 
 
 
 
 
fe2a734
 
a11decd
70e2016
 
fe2a734
70e2016
 
 
 
 
 
f63c041
a11decd
70e2016
fe2a734
 
 
 
 
 
 
 
70e2016
 
b68834c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# app.py
import gradio as gr
import numpy as np
import pandas as pd
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# --- Config ---
CACHE_DIR = Path("trial_cache")  # directory holding the per-trial .npz caches
METADATA_FILE = "subject_trials.csv"  # per-trial metadata table

# Load metadata (source of truth for the event time).
# This runs at import time and fails fast: the UI cannot work without it.
try:
    metadata_df = pd.read_csv(METADATA_FILE)
except FileNotFoundError as err:
    raise FileNotFoundError(f"Error: Could not find '{METADATA_FILE}'.") from err

# Every lookup below filters on these columns; validate them up front instead
# of failing later with a bare KeyError. 'target_time' is checked first so its
# error message is reported first when several columns are missing.
for _required_col in ("target_time", "subject", "trial"):
    if _required_col not in metadata_df.columns:
        raise ValueError(f"Column '{_required_col}' not found in {METADATA_FILE}. Please check your CSV.")

# Blank subject cells (NaN) cannot be selected and would break sorting.
subjects = sorted(metadata_df["subject"].dropna().unique())
print(f"✅ Loaded metadata from {METADATA_FILE} with {len(subjects)} subjects.")

def get_trials_for_subject(sbj: str):
    """Return the trial numbers listed for subject *sbj*, in ascending order.

    Reads the module-level ``metadata_df`` table. Returns an empty list when
    the subject has no rows.
    """
    subject_trials = metadata_df.loc[metadata_df["subject"] == sbj, "trial"]
    return sorted(subject_trials.tolist())

def get_target_time(sbj: str, trial: int):
    """Fetch the corrected event time for (*sbj*, *trial*) directly from the CSV.

    Args:
        sbj: subject identifier, matched against the 'subject' column.
        trial: trial number, matched against the 'trial' column.

    Returns:
        The row's 'target_time' as a float.

    Raises:
        ValueError: if no row matches, or the matching row has an empty
            (NaN) 'target_time'. A NaN would otherwise be plotted as a
            silently broken marker line.
    """
    is_match = (metadata_df["subject"] == sbj) & (metadata_df["trial"] == trial)
    matches = metadata_df[is_match]
    if matches.empty:
        raise ValueError(f"No entry found for {sbj} Trial {trial} in CSV.")
    # If the CSV lists the same subject/trial twice, the first row wins.
    raw_time = matches.iloc[0]["target_time"]
    if pd.isna(raw_time):
        raise ValueError(f"Empty 'target_time' for {sbj} Trial {trial} in CSV.")
    return float(raw_time)

def load_trial_data(sbj: str, trial: int):
    """Load one trial's cached phase arrays and attach its CSV target time.

    Reads ``CACHE_DIR / "<sbj>_trial<NN>.npz"`` (trial zero-padded to two
    digits). The event time comes exclusively from the metadata CSV via
    ``get_target_time``, not from the cache file.

    Args:
        sbj: subject identifier used in the cache file name and CSV lookup.
        trial: trial number; cast to ``int`` for the file name.

    Returns:
        dict with keys 'times', 'rp_pi', 'phase_l_pi', 'phase_r_pi',
        'candidates', 'sbj', 'trial_number' and 'target_time'.

    Raises:
        FileNotFoundError: if the cache file does not exist.
        ValueError: from ``get_target_time`` when the CSV has no usable entry.
    """
    path = CACHE_DIR / f"{sbj}_trial{int(trial):02d}.npz"

    if not path.exists():
        raise FileNotFoundError(f"Trial file not found: {path}")

    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays
    # ('candidates' is unwrapped with .item(), presumably a pickled 0-d object
    # array; verify). Unpickling can execute code, so only point CACHE_DIR at
    # files you generated yourself.
    # The `with` closes the NpzFile's underlying file handle, so every array
    # must be read inside it.
    with np.load(path, allow_pickle=True) as data:
        trial_data = {
            'times': data['times'],
            'rp_pi': data['rp_pi'],
            'phase_l_pi': data['phase_l_pi'],
            'phase_r_pi': data['phase_r_pi'],
            'candidates': data['candidates'].item(),
            'sbj': str(data['sbj']),
            'trial_number': int(data['trial_number']),
        }

    # Get time exclusively from the CSV (the only event time that matters).
    trial_data['target_time'] = get_target_time(sbj, trial)
    return trial_data

def plot_full_trial_plotly(results: dict):
    """Build a two-row Plotly figure for one trial.

    Row 1 overlays the left (blue) and right (red) phase traces. Row 2 shows
    the relative phase (teal). The rows share an x-axis, and row 2 carries a
    range slider. A green vertical line marks ``results['target_time']`` in
    both rows.

    Args:
        results: dict as produced by ``load_trial_data``. Reads the keys
            'times', 'phase_l_pi', 'phase_r_pi', 'rp_pi', 'target_time',
            'sbj' and 'trial_number'.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    times = results['times']
    phase_l = results['phase_l_pi']
    phase_r = results['phase_r_pi']
    rp = results['rp_pi']

    event_time = results['target_time']

    fig = make_subplots(
        rows=2, cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        subplot_titles=(
            f"{results['sbj']} – Trial {results['trial_number']}: Phases",
            "Relative Phase (R − L)"
        )
    )

    fig.add_trace(go.Scatter(x=times, y=phase_l, mode='lines', name='Left Phase', line=dict(color='blue')), row=1, col=1)
    fig.add_trace(go.Scatter(x=times, y=phase_r, mode='lines', name='Right Phase', line=dict(color='red')), row=1, col=1)
    fig.add_trace(go.Scatter(x=times, y=rp, mode='lines', name='Relative Phase', line=dict(color='teal')), row=2, col=1)

    # 🟢 Green line: the corrected time from the CSV. Only the top row gets
    # the text annotation, so the label is not duplicated.
    fig.add_vline(x=event_time, line=dict(color='green', width=4),
                  annotation_text="Corrected", annotation_position="top left",
                  row=1, col=1)
    fig.add_vline(x=event_time, line=dict(color='green', width=4), row=2, col=1)

    fig.update_layout(
        height=700,
        hovermode='x unified',
        title_x=0.5,
        xaxis2_rangeslider_visible=True,
        xaxis2_rangeslider_thickness=0.08
    )
    fig.update_yaxes(title_text="Phase (×π rad)", row=1, col=1)
    fig.update_yaxes(title_text="RP (×π rad)", row=2, col=1)
    fig.update_xaxes(title_text="Time (s)", row=2, col=1)

    return fig

# --- Gradio UI ---
with gr.Blocks(title="Phase Inspector") as demo:
    gr.Markdown("## Phase Transition Inspector")
    gr.Markdown("🟢 **Green Line**: Time defined in `subject_trials.csv`")

    plot_output = gr.Plot()

    def _is_selected(value) -> bool:
        """Return True when a dropdown holds a real choice.

        Only None and "" count as "nothing selected". Truthiness is avoided
        so that a subject or trial labelled 0 is still treated as selected.
        """
        return value is not None and value != ""

    def _message_figure(title: str):
        """Return an empty Plotly figure that only shows *title*."""
        fig = go.Figure()
        fig.update_layout(title=title)
        return fig

    # Pre-select the first subject and its first trial so the page opens populated.
    initial_sbj = subjects[0] if subjects else None
    initial_trials = get_trials_for_subject(initial_sbj) if _is_selected(initial_sbj) else []
    initial_trial = initial_trials[0] if initial_trials else None

    with gr.Row():
        sbj_dropdown = gr.Dropdown(choices=subjects, label="Subject", value=initial_sbj)
        trial_dropdown = gr.Dropdown(choices=initial_trials, label="Trial", value=initial_trial)
        load_btn = gr.Button("🔍 Load Trial")

    def on_subject_change(sbj: str):
        """Repopulate the trial dropdown for *sbj*, pre-selecting its first trial."""
        if not _is_selected(sbj):
            return gr.Dropdown(choices=[], value=None)
        trials = get_trials_for_subject(sbj)
        first_trial = trials[0] if trials else None
        return gr.Dropdown(choices=trials, value=first_trial)

    def on_load_trial(sbj: str, trial: int):
        """Load and plot one trial.

        On failure, returns a blank figure whose title shows the error.
        """
        if not (_is_selected(sbj) and _is_selected(trial)):
            return _message_figure("Please select a subject and trial.")

        # UI boundary: any loading or plotting error is shown to the user in
        # the figure title rather than crashing the event handler.
        try:
            results = load_trial_data(sbj, trial)
            return plot_full_trial_plotly(results)
        except Exception as e:
            return _message_figure(f"Error: {e}")

    sbj_dropdown.change(
        fn=on_subject_change,
        inputs=sbj_dropdown,
        outputs=trial_dropdown
    )

    load_btn.click(
        fn=on_load_trial,
        inputs=[sbj_dropdown, trial_dropdown],
        outputs=plot_output
    )

    # Auto-load on start. Uses the explicit selection check so trial 0 is not skipped.
    if _is_selected(initial_sbj) and _is_selected(initial_trial):
        demo.load(
            fn=on_load_trial,
            inputs=[sbj_dropdown, trial_dropdown],
            outputs=plot_output
        )

if __name__ == "__main__":
    # Serve the app only when executed as a script, not on import.
    soft_theme = gr.themes.Soft()
    demo.launch(theme=soft_theme)