File size: 4,519 Bytes
9b23ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""Shared session state management and data I/O utilities.

Manages cross-page state (selected ROIs, predictions, analysis log)
and provides upload/download widgets.
"""

import io
import json
from datetime import datetime

import numpy as np
import pandas as pd
import streamlit as st


def init_session():
    """Initialize session state with defaults. Safe to call multiple times."""
    defaults = {
        "brain_predictions": None,
        "model_features": {},
        "roi_indices": None,
        "n_vertices": 0,
        "selected_rois": [],
        "data_source": "synthetic",
        "stimulus_type": "visual",
        "tr_seconds": 1.0,
        "n_timepoints": 80,
        "seed": 42,
        "analysis_log": [],
        "carry_rois": [],  # ROIs carried from another page
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


def log_analysis(description):
    """Append an entry to the analysis log."""
    timestamp = datetime.now().strftime("%H:%M:%S")
    entry = f"[{timestamp}] {description}"
    if "analysis_log" not in st.session_state:
        st.session_state["analysis_log"] = []
    st.session_state["analysis_log"].append(entry)


def carry_rois(rois, target_page=""):
    """Store selected ROIs for cross-page workflow."""
    st.session_state["carry_rois"] = list(rois)
    log_analysis(f"Carried {len(rois)} ROIs to {target_page}")


def get_carried_rois():
    """Retrieve ROIs carried from another page."""
    return st.session_state.get("carry_rois", [])


def get_or_generate_data(roi_indices):
    """Get brain predictions from session or generate new synthetic data."""
    from synthetic import generate_realistic_predictions

    params_key = (
        st.session_state.get("n_timepoints", 80),
        st.session_state.get("stimulus_type", "visual"),
        st.session_state.get("seed", 42),
    )

    # Check if we need to regenerate
    if (
        st.session_state.get("brain_predictions") is None
        or st.session_state.get("_data_params") != params_key
        or st.session_state.get("data_source") == "synthetic"
    ):
        if st.session_state.get("data_source") == "uploaded" and st.session_state.get("brain_predictions") is not None:
            return st.session_state["brain_predictions"]

        predictions = generate_realistic_predictions(
            n_timepoints=st.session_state["n_timepoints"],
            roi_indices=roi_indices,
            stimulus_type=st.session_state["stimulus_type"],
            tr_seconds=st.session_state["tr_seconds"],
            seed=st.session_state["seed"],
        )
        st.session_state["brain_predictions"] = predictions
        st.session_state["_data_params"] = params_key

    return st.session_state["brain_predictions"]


def upload_npy_widget(label, key):
    """File uploader for .npy arrays with validation."""
    uploaded = st.file_uploader(label, type=["npy"], key=key)
    if uploaded is not None:
        try:
            data = np.load(io.BytesIO(uploaded.read()))
            st.success(f"Loaded: shape {data.shape}, dtype {data.dtype}")
            return data
        except Exception as e:
            st.error(f"Failed to load file: {e}")
    return None


def download_csv_button(df, filename, label="Download CSV"):
    """Download button for a pandas DataFrame as CSV."""
    csv = df.to_csv(index=False)
    st.download_button(label, csv, filename, "text/csv")


def download_json_button(data, filename, label="Download JSON"):
    """Download button for a dict as JSON."""
    json_str = json.dumps(data, indent=2, default=str)
    st.download_button(label, json_str, filename, "application/json")


def show_analysis_log():
    """Display the analysis log in the sidebar."""
    log = st.session_state.get("analysis_log", [])
    if log:
        with st.sidebar:
            with st.expander("Analysis Log", expanded=False):
                for entry in reversed(log[-20:]):
                    st.caption(entry)


def data_summary_widget(predictions, roi_indices):
    """Show a summary of the current data."""
    if predictions is None:
        st.info("No data loaded. Generate synthetic data or upload your own.")
        return

    col1, col2, col3, col4 = st.columns(4)
    col1.metric("Timepoints", predictions.shape[0])
    col2.metric("Vertices", predictions.shape[1])
    col3.metric("ROIs", len(roi_indices))
    col4.metric("Source", st.session_state.get("data_source", "synthetic").title())