# cortexlab-dashboard / pages /6_Live_Inference.py
# siddhant-rajhans
# Add live brain prediction mode (webcam, screen capture, video file)
# ab204cc
"""Live Brain Prediction - Real-Time Inference from Webcam, Screen, or Video."""
import time
import numpy as np
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots
from session import init_session, show_analysis_log
from theme import inject_theme, glow_card, section_header
from utils import make_roi_indices, COGNITIVE_DIMENSIONS
# Configure the page before anything else renders, then run the shared
# session/theme/log setup hooks in their original order.
_PAGE_CONFIG = {"page_title": "Live Inference", "page_icon": "🔴", "layout": "wide"}
st.set_page_config(**_PAGE_CONFIG)
for _setup_step in (init_session, inject_theme, show_analysis_log):
    _setup_step()

st.title("🔴 Live Brain Prediction")
st.markdown("Real-time brain activation prediction from webcam, screen capture, or video file.")
# --- Check Dependencies ---
# The live capture/engine modules may fail to import. Record why, fall back to
# simulation-mode defaults, and tell the user instead of crashing the page.
deps_ok = True
missing = []
try:
    from live_capture import WebcamCapture, ScreenCapture, FileStreamer, get_capture_source
    from live_engine import LiveInferenceEngine, CORTEXLAB_AVAILABLE
except ImportError as e:
    deps_ok = False
    missing.append(str(e))
    # The sidebar and the "About" expander read this flag unconditionally;
    # without a fallback binding the page dies with a NameError.
    CORTEXLAB_AVAILABLE = False

if not deps_ok:
    st.error(f"Live capture dependencies failed to import: {'; '.join(missing)}")
# --- Sidebar ---
# Bound up front so later code can test it regardless of the chosen source.
uploaded_file = None

with st.sidebar:
    st.header("Live Inference")
    source_type = st.selectbox(
        "Source",
        ["webcam", "screen", "file"],
        format_func={"webcam": "Webcam + Mic", "screen": "Screen Capture", "file": "Video File"}.get,
    )
    if source_type == "file":
        uploaded_file = st.file_uploader("Upload video", type=["mp4", "avi", "mkv", "mov", "webm"])

    st.subheader("Settings")
    capture_fps = st.slider(
        "Capture FPS", 0.5, 5.0, 1.0, 0.5,
        help="Frames per second. Higher = more responsive but more CPU/GPU load.",
    )
    if CORTEXLAB_AVAILABLE:
        device = st.selectbox("Device", ["auto", "cuda", "cpu"])
        st.success("CortexLab detected. Real inference available.")
    else:
        # Simulation mode is CPU-only, so there is no device to choose.
        device = "cpu"
        st.warning("CortexLab not installed. Running in **simulation mode** (predictions from image statistics).")
        with st.expander("Install CortexLab"):
            # Quoted: an unquoted `[analysis]` is a glob pattern in zsh and fails.
            st.code('pip install -e "../cortexlab[analysis]"', language="bash")

    st.subheader("Display")
    show_brain_3d = st.checkbox("Show 3D brain", value=True)
    show_timeline = st.checkbox("Show cognitive load timeline", value=True)
    timeline_window = st.slider("Timeline window (seconds)", 10, 120, 60)
# --- Initialize Engine ---
roi_indices, n_vertices = make_roi_indices()

# Seed the per-session live state on first load only; later reruns keep
# whatever the Start/Stop handlers stored.
for _state_key, _initial in (("live_engine", None), ("live_running", False)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
# --- Controls ---
# Start is enabled only while idle and Stop only while live, so each click
# corresponds to a real state transition.
_is_running = st.session_state.get("live_running", False)
col_start, col_stop, col_status = st.columns([1, 1, 2])
with col_start:
    start_clicked = st.button(
        "▶ Start",
        type="primary",
        use_container_width=True,
        disabled=_is_running,
    )
with col_stop:
    stop_clicked = st.button(
        "⬛ Stop",
        use_container_width=True,
        disabled=not _is_running,
    )
# Handle Start
# Build the selected capture source, attach a fresh inference engine to it,
# mark the session live, and rerun so the page switches into live mode.
if start_clicked and not deps_ok:
    # Explain why nothing happens instead of silently ignoring the click.
    st.error(f"Cannot start live inference: {'; '.join(missing)}")
elif start_clicked:
    import contextlib
    import os
    import tempfile

    # Remove the video written by this session's previous Start. Start is only
    # enabled while stopped, so presumably no engine is still reading it.
    previous_upload = st.session_state.get("live_upload_path")
    if previous_upload:
        with contextlib.suppress(OSError):
            os.remove(previous_upload)
        st.session_state["live_upload_path"] = None

    # Create capture source
    if source_type == "webcam":
        capture = WebcamCapture(fps=capture_fps)
    elif source_type == "screen":
        capture = ScreenCapture(fps=capture_fps)
    elif source_type == "file":
        if uploaded_file is None:
            st.error("Upload a video file first.")
            st.stop()
        # Write to a unique temp file rather than <tempdir>/<client filename>:
        # the client-supplied name can collide across sessions or contain path
        # components. Only its extension is kept, in case the reader keys off it.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # getvalue() returns the whole upload regardless of the read cursor.
            tmp.write(uploaded_file.getvalue())
        tmp_path = tmp.name
        st.session_state["live_upload_path"] = tmp_path
        capture = FileStreamer(file_path=tmp_path, fps=capture_fps)

    # Create and start engine
    engine = LiveInferenceEngine(
        n_vertices=n_vertices,
        roi_indices=roi_indices,
        device=device,
    )
    engine.start(capture)
    st.session_state["live_engine"] = engine
    st.session_state["live_running"] = True
    st.rerun()
# Handle Stop
# Halt the engine (the object itself is left in session state), mark the
# session idle, and rerun so the page returns to the idle view.
if stop_clicked:
    active_engine = st.session_state.get("live_engine")
    if active_engine:
        active_engine.stop()
    st.session_state["live_running"] = False
    st.rerun()
# --- Status Bar ---
# `engine` stays bound at module level: the live display below reuses it.
engine = st.session_state.get("live_engine")
_live_flag = st.session_state.get("live_running")
with col_status:
    if engine and _live_flag:
        stats = engine.get_metrics()
        status_html = f"""
<div style="display: flex; gap: 1.5rem; align-items: center; padding: 0.5rem;">
<span style="color: #EF4444; font-size: 1.2rem;">● LIVE</span>
<span style="color: #94A3B8;">Mode: <b style="color: #06B6D4;">{stats.mode}</b></span>
<span style="color: #94A3B8;">FPS: <b style="color: #10B981;">{stats.fps:.1f}</b></span>
<span style="color: #94A3B8;">Predictions: <b style="color: #A29BFE;">{stats.total_predictions}</b></span>
<span style="color: #94A3B8;">Latency: <b style="color: #FFEAA7;">{stats.avg_latency_ms:.0f}ms</b></span>
</div>
"""
        st.markdown(status_html, unsafe_allow_html=True)
    elif not _live_flag:
        idle_html = '<span style="color: #64748B;">Ready. Select a source and click Start.</span>'
        st.markdown(idle_html, unsafe_allow_html=True)
st.divider()

# --- Live Display ---
# While a session is live, redraw from the engine's recent predictions and
# schedule the next refresh; otherwise show the idle instructions.
if st.session_state.get("live_running") and engine:
    # Presumably the predictions from the last `timeline_window` seconds
    # (the slider is labelled in seconds) — confirm against LiveInferenceEngine.
    predictions = engine.get_predictions(timeline_window)
    if predictions:
        latest = predictions[-1]

        # --- Cognitive Load Metrics ---
        cog = latest.cognitive_load
        metric_cards = (
            ("Overall", "Overall", "#7C3AED"),
            ("Visual", "Visual Complexity", "#00D2FF"),
            ("Auditory", "Auditory Demand", "#FF6B6B"),
            ("Language", "Language Processing", "#A29BFE"),
            ("Executive", "Executive Load", "#FFEAA7"),
        )
        for column, (label, key, color) in zip(st.columns(len(metric_cards)), metric_cards):
            with column:
                glow_card(label, f"{cog.get(key, 0):.2f}", "", color)

        col_brain, col_timeline = st.columns([1, 1])

        # --- 3D Brain ---
        if show_brain_3d:
            with col_brain:
                section_header("Brain Activation", f"t = {latest.timestamp:.1f}s")
                try:
                    from brain_mesh import load_fsaverage_mesh, render_interactive_3d

                    coords, faces = load_fsaverage_mesh("left", "fsaverage4")  # Fast mesh for live
                    n_mesh = coords.shape[0]
                    # Fit the prediction vector to the mesh: stretch a short one by
                    # linear interpolation over vertex index, truncate a long one.
                    # NOTE(review): truncation assumes the first n_mesh entries map
                    # onto this mesh's vertices — confirm the vertex ordering.
                    vd = latest.vertex_data
                    if len(vd) < n_mesh:
                        vd = np.interp(np.linspace(0, len(vd) - 1, n_mesh), np.arange(len(vd)), vd)
                    elif len(vd) > n_mesh:
                        vd = vd[:n_mesh]
                    fig_brain = render_interactive_3d(
                        coords, faces, vd, cmap="Inferno", vmin=0, vmax=0.8,
                        bg_color="#050510", initial_view="Lateral Left",
                    )
                    if fig_brain:
                        fig_brain.update_layout(height=400, margin=dict(l=0, r=0, t=0, b=0))
                        st.plotly_chart(fig_brain, use_container_width=True)
                except Exception as e:
                    # Rendering is best-effort: a mesh/plot failure is reported
                    # but must not break the refresh loop.
                    st.warning(f"Brain render error: {e}")

        # --- Cognitive Load Timeline ---
        if show_timeline:
            with col_timeline:
                section_header("Cognitive Load Timeline", f"{len(predictions)} data points")
                fig_tl = go.Figure()
                timestamps = [p.timestamp for p in predictions]
                dim_colors = {
                    "Visual Complexity": "#00D2FF",
                    "Auditory Demand": "#FF6B6B",
                    "Language Processing": "#A29BFE",
                    "Executive Load": "#FFEAA7",
                }
                for dim, color in dim_colors.items():
                    values = [p.cognitive_load.get(dim, 0) for p in predictions]
                    fig_tl.add_trace(go.Scatter(
                        x=timestamps, y=values, name=dim.split()[0],
                        line=dict(color=color, width=2), mode="lines",
                    ))
                fig_tl.update_layout(
                    xaxis_title="Time (seconds)", yaxis_title="Load",
                    yaxis_range=[0, 1.05], height=400,
                    template="plotly_dark",
                    legend=dict(orientation="h", yanchor="bottom", y=1.02),
                    margin=dict(l=40, r=10, t=10, b=40),
                )
                st.plotly_chart(fig_tl, use_container_width=True)

        # --- Store latest predictions for other pages ---
        # One row per prediction in the window (assuming each vertex_data is
        # 1-D), so the analysis pages can read the live buffer.
        all_vertex_data = np.array([p.vertex_data for p in predictions])
        st.session_state["brain_predictions"] = all_vertex_data
        st.session_state["roi_indices"] = roi_indices
        st.session_state["data_source"] = "live_inference"

        # --- Navigation ---
        st.divider()
        st.markdown("**Explore live predictions in other tools:**")
        c1, c2, c3, c4 = st.columns(4)
        with c1:
            st.page_link("pages/5_Brain_Viewer.py", label="Brain Viewer", icon="🧠")
        with c2:
            st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load", icon="📊")
        with c3:
            st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
        with c4:
            st.page_link("pages/4_Connectivity.py", label="Connectivity", icon="🔗")
    else:
        st.caption("Waiting for the first prediction…")

    # --- Auto-refresh ---
    # Deliberately outside `if predictions:` so the page keeps polling until
    # the engine produces its first prediction instead of freezing after Start.
    time.sleep(1.0)
    st.rerun()
else:
    # --- Not running: show instructions ---
    st.markdown("""
<div style="
text-align: center; padding: 3rem 2rem;
background: rgba(15, 15, 40, 0.4);
border: 1px solid rgba(100, 100, 255, 0.15);
border-radius: 16px; margin: 1rem 0;
">
<div style="font-size: 3rem; margin-bottom: 1rem;">🧠</div>
<h3 style="color: #F1F5F9; margin-bottom: 0.5rem;">Ready for Live Brain Prediction</h3>
<p style="color: #94A3B8; max-width: 600px; margin: 0 auto;">
Select a source (webcam, screen capture, or video file) from the sidebar,
then click <b>Start</b> to begin real-time brain activation prediction.
</p>
<div style="margin-top: 1.5rem; display: flex; justify-content: center; gap: 2rem;">
<div style="text-align: center;">
<div style="font-size: 1.5rem;">📹</div>
<div style="color: #06B6D4; font-size: 0.85rem; font-weight: 600;">Webcam</div>
<div style="color: #64748B; font-size: 0.75rem;">Live camera feed</div>
</div>
<div style="text-align: center;">
<div style="font-size: 1.5rem;">🖥️</div>
<div style="color: #7C3AED; font-size: 0.85rem; font-weight: 600;">Screen</div>
<div style="color: #64748B; font-size: 0.75rem;">Capture display</div>
</div>
<div style="text-align: center;">
<div style="font-size: 1.5rem;">🎬</div>
<div style="color: #EC4899; font-size: 0.85rem; font-weight: 600;">Video File</div>
<div style="color: #64748B; font-size: 0.75rem;">Frame-by-frame</div>
</div>
</div>
</div>
""", unsafe_allow_html=True)

    # Show last predictions if available. len() counts rows for an ndarray and
    # also works if another page stored a plain list.
    previous_predictions = st.session_state.get("brain_predictions")
    if previous_predictions is not None and st.session_state.get("data_source") == "live_inference":
        st.info(f"Previous session predictions available ({len(previous_predictions)} timepoints). Navigate to analysis pages to explore them.")
# --- Methodology ---
# Explain the active mode, sources and performance expectations. Blank lines
# separate the markdown paragraphs; without them the mode text runs together
# and later paragraphs get folded into the preceding list items.
with st.expander("About Live Inference", expanded=False):
    if CORTEXLAB_AVAILABLE:
        mode_label = "Real (CortexLab)"
        mode_details = (
            "**Real Inference**: Uses TRIBE v2 to extract features (V-JEPA2, Wav2Vec-BERT, "
            "LLaMA 3.2) and predict fMRI brain activation at each captured frame. "
            "Requires GPU for interactive speed."
        )
    else:
        mode_label = "Simulation"
        mode_details = (
            "**Simulation Mode**: CortexLab is not installed. Predictions are generated from "
            "image statistics (brightness, contrast, color variance) mapped to brain ROIs. "
            "This demonstrates the pipeline without requiring GPU or model weights."
        )
    st.markdown(f"""
**Mode: {mode_label}**

{mode_details}

**Sources:**

- **Webcam**: Captures frames via OpenCV. Requires `pip install opencv-python`.
- **Screen Capture**: Captures display via mss. Requires `pip install mss Pillow`.
- **Video File**: Reads uploaded video frame-by-frame at the specified FPS.

**Cognitive Load Dimensions** are computed from predicted vertex activations
grouped by HCP MMP1.0 ROIs (same method as the Cognitive Load Scorer page).

**Performance:**

- Simulation mode: ~1-5ms per frame (CPU)
- Real inference with GPU: ~50-200ms per frame
- Real inference with CPU: ~5-30s per frame (not recommended)

**To enable real inference:**

```bash
pip install -e "path/to/cortexlab[analysis]"
```
""")