Spaces:
Running
Running
File size: 4,519 Bytes
9b23ae9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 | """Shared session state management and data I/O utilities.
Manages cross-page state (selected ROIs, predictions, analysis log)
and provides upload/download widgets.
"""
import io
import json
from datetime import datetime
import numpy as np
import pandas as pd
import streamlit as st
def init_session():
    """Populate ``st.session_state`` with the shared default values.

    Only keys that are absent are written, so values set by the user or by
    other pages survive reruns and the function can be called any number
    of times without side effects on existing state.
    """
    state = st.session_state
    # Fresh list/dict objects are built on every call, so no two sessions
    # ever share a mutable default.
    for name, default_value in (
        ("brain_predictions", None),
        ("model_features", {}),
        ("roi_indices", None),
        ("n_vertices", 0),
        ("selected_rois", []),
        ("data_source", "synthetic"),
        ("stimulus_type", "visual"),
        ("tr_seconds", 1.0),
        ("n_timepoints", 80),
        ("seed", 42),
        ("analysis_log", []),
        ("carry_rois", []),  # ROIs carried from another page
    ):
        if name not in state:
            state[name] = default_value
def log_analysis(description):
    """Add a timestamped line to the session's analysis log.

    The entry has the form ``"[HH:MM:SS] <description>"`` using local
    wall-clock time. The log list is created on demand if it does not
    exist yet (e.g. ``init_session`` was not called).
    """
    state = st.session_state
    if "analysis_log" not in state:
        state["analysis_log"] = []
    stamp = datetime.now().strftime("%H:%M:%S")
    state["analysis_log"].append(f"[{stamp}] {description}")
def carry_rois(rois, target_page=""):
    """Store ROIs in the session so another page can pick them up.

    Args:
        rois: Any iterable of ROI identifiers. It is materialized into a new
            list, so generators and other one-shot iterables are supported.
        target_page: Optional name of the destination page; used only for
            the analysis-log message.
    """
    carried = list(rois)
    st.session_state["carry_rois"] = carried
    # Count the materialized list: calling len() on the original argument
    # fails for generators, which list() above has already exhausted.
    destination = f" to {target_page}" if target_page else ""
    log_analysis(f"Carried {len(carried)} ROIs{destination}")
def get_carried_rois():
    """Return the ROI list last stored by ``carry_rois``.

    When nothing has been carried yet, a new empty list is returned. The
    stored list itself (not a copy) is handed back to the caller.
    """
    fallback = []
    carried = st.session_state.get("carry_rois", fallback)
    return carried
def get_or_generate_data(roi_indices):
    """Return brain predictions from the session, generating synthetic data if needed.

    Uploaded predictions (``data_source == "uploaded"`` with a non-None
    ``brain_predictions``) are returned unchanged. Otherwise predictions are
    produced by ``synthetic.generate_realistic_predictions`` and stored in
    ``st.session_state["brain_predictions"]``, with the generation
    parameters recorded in ``st.session_state["_data_params"]``.

    Args:
        roi_indices: ROI definition forwarded unchanged to the generator.

    Returns:
        The value of ``st.session_state["brain_predictions"]`` after any
        regeneration.
    """
    from synthetic import generate_realistic_predictions

    state = st.session_state
    cached = state.get("brain_predictions")
    source = state.get("data_source")

    # User-supplied data always wins; never overwrite it with synthetic data.
    if source == "uploaded" and cached is not None:
        return cached

    # Resolve every generation parameter exactly once (defaults match
    # init_session) so the cache key and the generator call always agree
    # and a missing key does not raise KeyError.
    n_timepoints = state.get("n_timepoints", 80)
    stimulus_type = state.get("stimulus_type", "visual")
    tr_seconds = state.get("tr_seconds", 1.0)
    seed = state.get("seed", 42)
    params_key = (n_timepoints, stimulus_type, seed, tr_seconds)

    # NOTE(review): roi_indices is not part of params_key, so synthetic data
    # is regenerated on every call to stay in sync with the requested ROIs.
    # Confirm before tightening this into a pure cache hit.
    needs_refresh = (
        cached is None
        or state.get("_data_params") != params_key
        or source == "synthetic"
    )
    if needs_refresh:
        state["brain_predictions"] = generate_realistic_predictions(
            n_timepoints=n_timepoints,
            roi_indices=roi_indices,
            stimulus_type=stimulus_type,
            tr_seconds=tr_seconds,
            seed=seed,
        )
        state["_data_params"] = params_key
    return state["brain_predictions"]
def upload_npy_widget(label, key):
    """Render a ``.npy`` file uploader and return the loaded array.

    Args:
        label: Text shown on the uploader widget.
        key: Unique Streamlit widget key.

    Returns:
        The loaded array, or ``None`` when no file is uploaded or the file
        cannot be loaded (an error message is displayed in that case).
    """
    uploaded = st.file_uploader(label, type=["npy"], key=key)
    if uploaded is None:
        return None
    try:
        # getvalue() returns the whole buffer independent of the current
        # stream position, unlike read(), which returns b"" once consumed.
        # allow_pickle=False: never unpickle user-supplied data.
        data = np.load(io.BytesIO(uploaded.getvalue()), allow_pickle=False)
        # Kept inside the try: a non-array payload (e.g. a renamed .npz)
        # has no .shape and should surface as a load error, not a crash.
        st.success(f"Loaded: shape {data.shape}, dtype {data.dtype}")
        return data
    except Exception as e:  # UI boundary: report any failure to the user
        st.error(f"Failed to load file: {e}")
    return None
def download_csv_button(df, filename, label="Download CSV"):
    """Render a button that downloads *df* as CSV text named *filename*.

    The DataFrame index is left out of the exported file.
    """
    st.download_button(
        label,
        data=df.to_csv(index=False),
        file_name=filename,
        mime="text/csv",
    )
def _json_default(obj):
    """Fallback encoder for ``json.dumps`` used by ``download_json_button``.

    NumPy scalars become the equivalent Python number/bool and NumPy arrays
    become nested lists, so numeric data stays numeric in the exported JSON.
    Any other unsupported object falls back to ``str(obj)``.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return str(obj)


def download_json_button(data, filename, label="Download JSON"):
    """Render a button that downloads *data* as pretty-printed JSON.

    Args:
        data: JSON-like object (typically a dict). NumPy scalars/arrays are
            converted to native JSON numbers/lists rather than strings;
            other non-serializable values are stringified.
        filename: Name given to the downloaded file.
        label: Button caption.
    """
    # A bare default=str would turn np.int64 into "3" and arrays into
    # truncated "[1 2 ... 9]" text; _json_default keeps them numeric.
    json_str = json.dumps(data, indent=2, default=_json_default)
    st.download_button(label, json_str, filename, "application/json")
def show_analysis_log():
    """Show up to the 20 newest analysis-log entries in a sidebar expander.

    Entries are listed newest first; nothing is rendered when the log is
    missing or empty.
    """
    entries = st.session_state.get("analysis_log", [])
    if not entries:
        return
    newest_first = entries[-20:][::-1]
    with st.sidebar, st.expander("Analysis Log", expanded=False):
        for line in newest_first:
            st.caption(line)
def data_summary_widget(predictions, roi_indices):
    """Show headline metrics for the current prediction data.

    Args:
        predictions: Array-like predictions, or ``None`` to show an
            informational message instead. The first axis is reported as
            "Timepoints" and the second as "Vertices" (presumably shaped
            (timepoints, vertices) — confirm against producers); missing
            axes are shown as "n/a".
        roi_indices: Sized collection of ROIs, or ``None`` (the
            ``init_session`` default), which is reported as 0.
    """
    if predictions is None:
        st.info("No data loaded. Generate synthetic data or upload your own.")
        return

    shape = np.shape(predictions)
    # Uploaded .npy files may have any rank; don't crash on 0-D/1-D input.
    n_timepoints = shape[0] if len(shape) >= 1 else "n/a"
    n_vertices = shape[1] if len(shape) >= 2 else "n/a"
    n_rois = len(roi_indices) if roi_indices is not None else 0
    # `or` also covers a key that exists but holds None.
    source = st.session_state.get("data_source") or "synthetic"

    col1, col2, col3, col4 = st.columns(4)
    col1.metric("Timepoints", n_timepoints)
    col2.metric("Vertices", n_vertices)
    col3.metric("ROIs", n_rois)
    col4.metric("Source", str(source).title())
|