siddhant-rajhans commited on
Commit
ab204cc
·
1 Parent(s): 4f96c68

Add live brain prediction mode (webcam, screen capture, video file)

Browse files

New page: 6_Live_Inference.py
- Real-time brain activation from webcam, screen capture, or uploaded video
- 3D brain visualization updating live with Inferno colormap
- Rolling cognitive load timeline (Visual/Auditory/Language/Executive)
- Glowing metric cards updating every prediction
- Start/Stop controls with FPS and latency indicators
- Status bar showing mode (simulation/cortexlab), FPS, prediction count
- Predictions stored in session state for use in all other analysis pages
- Simulation mode works without CortexLab (predictions from image statistics)
- Real mode uses TRIBE v2 when CortexLab + GPU available

New utilities:
- live_capture.py: WebcamCapture (OpenCV), ScreenCapture (mss), FileStreamer
- All run in background threads, yield MediaFrame objects
- Configurable FPS, thread-safe buffer
- live_engine.py: LiveInferenceEngine
- Background thread consuming frames and producing predictions
- Dual mode: CortexLab real inference or simulation fallback
- Computes cognitive load dimensions from vertex activations
- Tracks metrics (FPS, latency, prediction count)

Updated home page with Live Inference feature card

Files changed (5) hide show
  1. Home.py +4 -10
  2. app.py +4 -10
  3. live_capture.py +208 -0
  4. live_engine.py +284 -0
  5. pages/6_Live_Inference.py +294 -0
Home.py CHANGED
@@ -97,17 +97,11 @@ with col5:
97
 
98
  with col6:
99
  st.markdown(feature_card(
100
- "", "Streaming Inference",
101
- "Real-time sliding-window predictions for BCI pipelines. Cross-subject adaptation with minimal calibration data.",
102
- "#F59E0B"
103
  ), unsafe_allow_html=True)
104
- st.markdown(f"""
105
- <a href="https://github.com/siddhant-rajhans/cortexlab" target="_blank" style="
106
- display: inline-block; padding: 0.4rem 1rem;
107
- color: #F59E0B; font-size: 0.85rem;
108
- text-decoration: none;
109
- ">View on GitHub &rarr;</a>
110
- """, unsafe_allow_html=True)
111
 
112
  # --- Data Config (collapsed) ---
113
  st.markdown("<div style='height: 1rem'></div>", unsafe_allow_html=True)
 
97
 
98
  with col6:
99
  st.markdown(feature_card(
100
+ "🔴", "Live Inference",
101
+ "Real-time brain prediction from webcam, screen capture, or video. All metrics update live. Works in simulation mode or with full CortexLab + GPU.",
102
+ "#EF4444"
103
  ), unsafe_allow_html=True)
104
+ st.page_link("pages/6_Live_Inference.py", label="Open Live Inference")
 
 
 
 
 
 
105
 
106
  # --- Data Config (collapsed) ---
107
  st.markdown("<div style='height: 1rem'></div>", unsafe_allow_html=True)
app.py CHANGED
@@ -97,17 +97,11 @@ with col5:
97
 
98
  with col6:
99
  st.markdown(feature_card(
100
- "", "Streaming Inference",
101
- "Real-time sliding-window predictions for BCI pipelines. Cross-subject adaptation with minimal calibration data.",
102
- "#F59E0B"
103
  ), unsafe_allow_html=True)
104
- st.markdown(f"""
105
- <a href="https://github.com/siddhant-rajhans/cortexlab" target="_blank" style="
106
- display: inline-block; padding: 0.4rem 1rem;
107
- color: #F59E0B; font-size: 0.85rem;
108
- text-decoration: none;
109
- ">View on GitHub &rarr;</a>
110
- """, unsafe_allow_html=True)
111
 
112
  # --- Data Config (collapsed) ---
113
  st.markdown("<div style='height: 1rem'></div>", unsafe_allow_html=True)
 
97
 
98
  with col6:
99
  st.markdown(feature_card(
100
+ "🔴", "Live Inference",
101
+ "Real-time brain prediction from webcam, screen capture, or video. All metrics update live. Works in simulation mode or with full CortexLab + GPU.",
102
+ "#EF4444"
103
  ), unsafe_allow_html=True)
104
+ st.page_link("pages/6_Live_Inference.py", label="Open Live Inference")
 
 
 
 
 
 
105
 
106
  # --- Data Config (collapsed) ---
107
  st.markdown("<div style='height: 1rem'></div>", unsafe_allow_html=True)
live_capture.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Media capture sources for live brain prediction.
2
+
3
+ Provides webcam, screen capture, and file streaming sources that
4
+ yield frames at a controlled rate for real-time inference.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import time
10
+ import threading
11
+ import logging
12
+ from pathlib import Path
13
+ from collections import deque
14
+ from dataclasses import dataclass
15
+
16
+ import numpy as np
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ @dataclass
22
+ class MediaFrame:
23
+ """A single frame from any media source."""
24
+ video_frame: np.ndarray | None = None # (H, W, 3) RGB
25
+ audio_chunk: np.ndarray | None = None # (samples,) float32
26
+ timestamp: float = 0.0
27
+
28
+
29
+ class BaseCapture:
30
+ """Base class for media capture sources."""
31
+
32
+ def __init__(self, fps: float = 1.0):
33
+ self.fps = fps
34
+ self._running = False
35
+ self._buffer: deque[MediaFrame] = deque(maxlen=300)
36
+ self._thread: threading.Thread | None = None
37
+ self._lock = threading.Lock()
38
+
39
+ def start(self):
40
+ self._running = True
41
+ self._thread = threading.Thread(target=self._capture_loop, daemon=True)
42
+ self._thread.start()
43
+
44
+ def stop(self):
45
+ self._running = False
46
+ if self._thread:
47
+ self._thread.join(timeout=3.0)
48
+
49
+ def get_latest_frame(self) -> MediaFrame | None:
50
+ with self._lock:
51
+ return self._buffer[-1] if self._buffer else None
52
+
53
+ def get_all_frames(self) -> list[MediaFrame]:
54
+ with self._lock:
55
+ frames = list(self._buffer)
56
+ return frames
57
+
58
+ @property
59
+ def is_running(self) -> bool:
60
+ return self._running
61
+
62
+ @property
63
+ def frame_count(self) -> int:
64
+ return len(self._buffer)
65
+
66
+ def _capture_loop(self):
67
+ raise NotImplementedError
68
+
69
+
70
+ class WebcamCapture(BaseCapture):
71
+ """Capture frames from webcam using OpenCV."""
72
+
73
+ def __init__(self, camera_index: int = 0, fps: float = 1.0, resolution: tuple = (640, 480)):
74
+ super().__init__(fps)
75
+ self.camera_index = camera_index
76
+ self.resolution = resolution
77
+
78
+ def _capture_loop(self):
79
+ try:
80
+ import cv2
81
+ except ImportError:
82
+ logger.error("OpenCV not installed. Run: pip install opencv-python")
83
+ self._running = False
84
+ return
85
+
86
+ cap = cv2.VideoCapture(self.camera_index)
87
+ if not cap.isOpened():
88
+ logger.error(f"Cannot open camera {self.camera_index}")
89
+ self._running = False
90
+ return
91
+
92
+ cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.resolution[0])
93
+ cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.resolution[1])
94
+ start_time = time.time()
95
+ interval = 1.0 / self.fps
96
+
97
+ try:
98
+ while self._running:
99
+ ret, frame = cap.read()
100
+ if not ret:
101
+ break
102
+ # BGR -> RGB
103
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
104
+ media_frame = MediaFrame(
105
+ video_frame=frame_rgb,
106
+ timestamp=time.time() - start_time,
107
+ )
108
+ with self._lock:
109
+ self._buffer.append(media_frame)
110
+ time.sleep(interval)
111
+ finally:
112
+ cap.release()
113
+
114
+
115
+ class ScreenCapture(BaseCapture):
116
+ """Capture screen frames using mss."""
117
+
118
+ def __init__(self, fps: float = 1.0, region: dict | None = None):
119
+ super().__init__(fps)
120
+ self.region = region # {"left": 0, "top": 0, "width": 1920, "height": 1080}
121
+
122
+ def _capture_loop(self):
123
+ try:
124
+ import mss
125
+ from PIL import Image
126
+ except ImportError:
127
+ logger.error("mss/PIL not installed. Run: pip install mss Pillow")
128
+ self._running = False
129
+ return
130
+
131
+ start_time = time.time()
132
+ interval = 1.0 / self.fps
133
+
134
+ with mss.mss() as sct:
135
+ monitor = self.region or sct.monitors[1] # Primary monitor
136
+ while self._running:
137
+ screenshot = sct.grab(monitor)
138
+ img = Image.frombytes("RGB", screenshot.size, screenshot.bgra, "raw", "BGRX")
139
+ frame = np.array(img)
140
+ media_frame = MediaFrame(
141
+ video_frame=frame,
142
+ timestamp=time.time() - start_time,
143
+ )
144
+ with self._lock:
145
+ self._buffer.append(media_frame)
146
+ time.sleep(interval)
147
+
148
+
149
+ class FileStreamer(BaseCapture):
150
+ """Stream a video file frame-by-frame at real-time speed."""
151
+
152
+ def __init__(self, file_path: str, fps: float = 1.0):
153
+ super().__init__(fps)
154
+ self.file_path = file_path
155
+
156
+ def _capture_loop(self):
157
+ try:
158
+ import cv2
159
+ except ImportError:
160
+ logger.error("OpenCV not installed. Run: pip install opencv-python")
161
+ self._running = False
162
+ return
163
+
164
+ cap = cv2.VideoCapture(self.file_path)
165
+ if not cap.isOpened():
166
+ logger.error(f"Cannot open video: {self.file_path}")
167
+ self._running = False
168
+ return
169
+
170
+ video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
171
+ # Skip frames to match our target FPS
172
+ frame_skip = max(1, int(video_fps / self.fps))
173
+ frame_idx = 0
174
+ start_time = time.time()
175
+ interval = 1.0 / self.fps
176
+
177
+ try:
178
+ while self._running:
179
+ ret, frame = cap.read()
180
+ if not ret:
181
+ self._running = False
182
+ break
183
+ frame_idx += 1
184
+ if frame_idx % frame_skip != 0:
185
+ continue
186
+
187
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
188
+ media_frame = MediaFrame(
189
+ video_frame=frame_rgb,
190
+ timestamp=time.time() - start_time,
191
+ )
192
+ with self._lock:
193
+ self._buffer.append(media_frame)
194
+ time.sleep(interval)
195
+ finally:
196
+ cap.release()
197
+
198
+
199
+ def get_capture_source(source_type: str, **kwargs) -> BaseCapture:
200
+ """Factory function to create a capture source."""
201
+ sources = {
202
+ "webcam": WebcamCapture,
203
+ "screen": ScreenCapture,
204
+ "file": FileStreamer,
205
+ }
206
+ if source_type not in sources:
207
+ raise ValueError(f"Unknown source: {source_type}. Choose from {list(sources)}")
208
+ return sources[source_type](**kwargs)
live_engine.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Real-time brain prediction engine.
2
+
3
+ Runs in a background thread, consuming frames from a capture source,
4
+ extracting features, and producing brain predictions via TRIBE v2.
5
+
6
+ When CortexLab is not installed, falls back to a simulation mode that
7
+ generates synthetic predictions from frame statistics.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import time
13
+ import threading
14
+ import logging
15
+ from collections import deque
16
+ from dataclasses import dataclass, field
17
+
18
+ import numpy as np
19
+
20
+ from live_capture import BaseCapture, MediaFrame
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Check if CortexLab is available
25
+ try:
26
+ from cortexlab.inference.predictor import TribeModel
27
+ CORTEXLAB_AVAILABLE = True
28
+ except ImportError:
29
+ CORTEXLAB_AVAILABLE = False
30
+
31
+
32
+ @dataclass
33
+ class LivePrediction:
34
+ """A single prediction with metadata."""
35
+ vertex_data: np.ndarray # (n_vertices,)
36
+ timestamp: float
37
+ cognitive_load: dict[str, float] = field(default_factory=dict)
38
+ processing_time_ms: float = 0.0
39
+
40
+
41
+ @dataclass
42
+ class LiveMetrics:
43
+ """Aggregated metrics from the live engine."""
44
+ fps: float = 0.0
45
+ total_frames: int = 0
46
+ total_predictions: int = 0
47
+ avg_latency_ms: float = 0.0
48
+ is_running: bool = False
49
+ mode: str = "simulation" # "simulation" or "cortexlab"
50
+
51
+
52
+ class LiveInferenceEngine:
53
+ """Background engine for real-time brain prediction.
54
+
55
+ Consumes frames from a capture source and produces brain predictions.
56
+ If CortexLab is installed and a GPU is available, uses the real TRIBE v2
57
+ model. Otherwise, falls back to simulation mode that generates plausible
58
+ predictions from frame statistics.
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ n_vertices: int = 580,
64
+ roi_indices: dict | None = None,
65
+ buffer_size: int = 120,
66
+ checkpoint: str = "facebook/tribev2",
67
+ device: str = "auto",
68
+ cache_folder: str = "./cache",
69
+ ):
70
+ self.n_vertices = n_vertices
71
+ self.roi_indices = roi_indices or {}
72
+ self.buffer_size = buffer_size
73
+ self.checkpoint = checkpoint
74
+ self.device = device
75
+ self.cache_folder = cache_folder
76
+
77
+ self._predictions: deque[LivePrediction] = deque(maxlen=buffer_size)
78
+ self._running = False
79
+ self._thread: threading.Thread | None = None
80
+ self._lock = threading.Lock()
81
+ self._model = None
82
+ self._metrics = LiveMetrics()
83
+ self._capture: BaseCapture | None = None
84
+
85
+ def start(self, capture: BaseCapture):
86
+ """Start the inference engine with a media capture source."""
87
+ if self._running:
88
+ return
89
+
90
+ self._capture = capture
91
+ self._running = True
92
+ self._metrics = LiveMetrics(is_running=True)
93
+
94
+ # Try to load CortexLab model
95
+ if CORTEXLAB_AVAILABLE:
96
+ try:
97
+ logger.info("Loading TRIBE v2 model...")
98
+ self._model = TribeModel.from_pretrained(
99
+ self.checkpoint, device=self.device, cache_folder=self.cache_folder
100
+ )
101
+ self._metrics.mode = "cortexlab"
102
+ logger.info("Model loaded. Using real inference.")
103
+ except Exception as e:
104
+ logger.warning(f"Failed to load model: {e}. Using simulation mode.")
105
+ self._model = None
106
+ self._metrics.mode = "simulation"
107
+ else:
108
+ self._metrics.mode = "simulation"
109
+
110
+ capture.start()
111
+ self._thread = threading.Thread(target=self._inference_loop, daemon=True)
112
+ self._thread.start()
113
+
114
+ def stop(self):
115
+ """Stop the engine and capture source."""
116
+ self._running = False
117
+ if self._capture:
118
+ self._capture.stop()
119
+ if self._thread:
120
+ self._thread.join(timeout=5.0)
121
+ self._metrics.is_running = False
122
+
123
+ def get_latest_prediction(self) -> LivePrediction | None:
124
+ with self._lock:
125
+ return self._predictions[-1] if self._predictions else None
126
+
127
+ def get_predictions(self, n: int = 60) -> list[LivePrediction]:
128
+ with self._lock:
129
+ return list(self._predictions)[-n:]
130
+
131
+ def get_metrics(self) -> LiveMetrics:
132
+ return self._metrics
133
+
134
+ def _inference_loop(self):
135
+ """Main loop: consume frames, produce predictions."""
136
+ frame_times = deque(maxlen=30)
137
+ last_frame_count = 0
138
+
139
+ while self._running:
140
+ frame = self._capture.get_latest_frame()
141
+ if frame is None:
142
+ time.sleep(0.1)
143
+ continue
144
+
145
+ # Skip if we already processed this frame
146
+ current_count = self._capture.frame_count
147
+ if current_count == last_frame_count:
148
+ time.sleep(0.05)
149
+ continue
150
+ last_frame_count = current_count
151
+
152
+ start = time.time()
153
+
154
+ if self._model is not None and self._metrics.mode == "cortexlab":
155
+ prediction = self._run_real_inference(frame)
156
+ else:
157
+ prediction = self._run_simulation(frame)
158
+
159
+ elapsed_ms = (time.time() - start) * 1000
160
+ prediction.processing_time_ms = elapsed_ms
161
+
162
+ with self._lock:
163
+ self._predictions.append(prediction)
164
+
165
+ # Update metrics
166
+ frame_times.append(time.time())
167
+ self._metrics.total_predictions += 1
168
+ self._metrics.total_frames = current_count
169
+ self._metrics.avg_latency_ms = elapsed_ms
170
+ if len(frame_times) >= 2:
171
+ self._metrics.fps = (len(frame_times) - 1) / (frame_times[-1] - frame_times[0])
172
+
173
+ # Check if capture stopped (file ended)
174
+ if not self._capture.is_running:
175
+ self._running = False
176
+ self._metrics.is_running = False
177
+
178
+ def _run_real_inference(self, frame: MediaFrame) -> LivePrediction:
179
+ """Run actual TRIBE v2 inference on a frame.
180
+
181
+ For real-time, we skip the full pipeline (get_events_dataframe)
182
+ and use a simplified feature extraction path.
183
+ """
184
+ import tempfile
185
+ import os
186
+
187
+ try:
188
+ # Save frame as temporary video (1 frame)
189
+ import cv2
190
+ tmp_path = os.path.join(tempfile.gettempdir(), "cortexlab_live_frame.mp4")
191
+ h, w = frame.video_frame.shape[:2]
192
+ out = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'mp4v'), 1, (w, h))
193
+ out.write(cv2.cvtColor(frame.video_frame, cv2.COLOR_RGB2BGR))
194
+ out.release()
195
+
196
+ events = self._model.get_events_dataframe(video_path=tmp_path)
197
+ preds, _ = self._model.predict(events, verbose=False)
198
+ vertex_data = preds.mean(axis=0) if preds.ndim == 2 else preds
199
+
200
+ # Normalize to [0, 1]
201
+ vmin, vmax = vertex_data.min(), vertex_data.max()
202
+ if vmax > vmin:
203
+ vertex_data = (vertex_data - vmin) / (vmax - vmin)
204
+
205
+ os.unlink(tmp_path)
206
+ except Exception as e:
207
+ logger.warning(f"Inference failed: {e}. Falling back to simulation.")
208
+ return self._run_simulation(frame)
209
+
210
+ cog_load = self._compute_cognitive_load(vertex_data)
211
+ return LivePrediction(
212
+ vertex_data=vertex_data,
213
+ timestamp=frame.timestamp,
214
+ cognitive_load=cog_load,
215
+ )
216
+
217
+ def _run_simulation(self, frame: MediaFrame) -> LivePrediction:
218
+ """Generate plausible predictions from frame statistics.
219
+
220
+ Uses frame brightness/color as proxy for visual complexity,
221
+ creating biologically-inspired activation patterns.
222
+ """
223
+ rng = np.random.default_rng(int(frame.timestamp * 1000) % (2**31))
224
+
225
+ # Base noise
226
+ vertex_data = rng.standard_normal(self.n_vertices) * 0.03
227
+
228
+ if frame.video_frame is not None:
229
+ img = frame.video_frame.astype(np.float32) / 255.0
230
+
231
+ # Visual complexity from image statistics
232
+ brightness = img.mean()
233
+ contrast = img.std()
234
+ color_variance = img.var(axis=(0, 1)).mean()
235
+
236
+ # Map to ROI activations
237
+ for roi_name, vertices in self.roi_indices.items():
238
+ valid = vertices[vertices < self.n_vertices]
239
+ if len(valid) == 0:
240
+ continue
241
+
242
+ # Visual ROIs respond to brightness/contrast
243
+ if roi_name in ["V1", "V2", "V3", "V4", "MT", "MST", "FFC", "VVC"]:
244
+ activation = contrast * 0.8 + color_variance * 0.5
245
+ # Auditory ROIs get low baseline
246
+ elif roi_name in ["A1", "LBelt", "MBelt", "PBelt", "A4", "A5"]:
247
+ activation = 0.05 + rng.random() * 0.1
248
+ # Language ROIs moderate
249
+ elif roi_name in ["44", "45", "IFJa", "IFJp", "TPOJ1", "TPOJ2"]:
250
+ activation = brightness * 0.3
251
+ # Executive ROIs track change
252
+ elif roi_name in ["46", "9-46d", "8Av", "8Ad", "FEF"]:
253
+ activation = contrast * 0.5
254
+ else:
255
+ activation = 0.1
256
+
257
+ vertex_data[valid] = activation + rng.standard_normal(len(valid)) * 0.05
258
+
259
+ vertex_data = np.clip(vertex_data, 0, 1)
260
+ cog_load = self._compute_cognitive_load(vertex_data)
261
+
262
+ return LivePrediction(
263
+ vertex_data=vertex_data,
264
+ timestamp=frame.timestamp,
265
+ cognitive_load=cog_load,
266
+ )
267
+
268
+ def _compute_cognitive_load(self, vertex_data: np.ndarray) -> dict[str, float]:
269
+ """Compute cognitive load dimensions from vertex data."""
270
+ from utils import COGNITIVE_DIMENSIONS
271
+
272
+ baseline = max(float(np.median(np.abs(vertex_data))), 1e-8)
273
+ scores = {}
274
+ for dim, rois in COGNITIVE_DIMENSIONS.items():
275
+ vals = []
276
+ for roi in rois:
277
+ if roi in self.roi_indices:
278
+ verts = self.roi_indices[roi]
279
+ valid = verts[verts < len(vertex_data)]
280
+ if len(valid) > 0:
281
+ vals.append(np.abs(vertex_data[valid]).mean())
282
+ scores[dim] = min(float(np.mean(vals)) / baseline, 1.0) if vals else 0.0
283
+ scores["Overall"] = float(np.mean(list(scores.values()))) if scores else 0.0
284
+ return scores
pages/6_Live_Inference.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Live Brain Prediction - Real-Time Inference from Webcam, Screen, or Video."""
2
+
3
+ import time
4
+
5
+ import numpy as np
6
+ import plotly.graph_objects as go
7
+ import streamlit as st
8
+ from plotly.subplots import make_subplots
9
+
10
+ from session import init_session, show_analysis_log
11
+ from theme import inject_theme, glow_card, section_header
12
+ from utils import make_roi_indices, COGNITIVE_DIMENSIONS
13
+
14
+ st.set_page_config(page_title="Live Inference", page_icon="🔴", layout="wide")
15
+ init_session()
16
+ inject_theme()
17
+ show_analysis_log()
18
+
19
+ st.title("🔴 Live Brain Prediction")
20
+ st.markdown("Real-time brain activation prediction from webcam, screen capture, or video file.")
21
+
22
+ # --- Check Dependencies ---
23
+ deps_ok = True
24
+ missing = []
25
+
26
+ try:
27
+ from live_capture import WebcamCapture, ScreenCapture, FileStreamer, get_capture_source
28
+ from live_engine import LiveInferenceEngine, CORTEXLAB_AVAILABLE
29
+ except ImportError as e:
30
+ deps_ok = False
31
+ missing.append(str(e))
32
+
33
+ # --- Sidebar ---
34
+ with st.sidebar:
35
+ st.header("Live Inference")
36
+
37
+ source_type = st.selectbox("Source", ["webcam", "screen", "file"],
38
+ format_func={"webcam": "Webcam + Mic", "screen": "Screen Capture", "file": "Video File"}.get)
39
+
40
+ if source_type == "file":
41
+ uploaded_file = st.file_uploader("Upload video", type=["mp4", "avi", "mkv", "mov", "webm"])
42
+
43
+ st.subheader("Settings")
44
+ capture_fps = st.slider("Capture FPS", 0.5, 5.0, 1.0, 0.5,
45
+ help="Frames per second. Higher = more responsive but more CPU/GPU load.")
46
+
47
+ if CORTEXLAB_AVAILABLE:
48
+ device = st.selectbox("Device", ["auto", "cuda", "cpu"])
49
+ st.success("CortexLab detected. Real inference available.")
50
+ else:
51
+ device = "cpu"
52
+ st.warning("CortexLab not installed. Running in **simulation mode** (predictions from image statistics).")
53
+ with st.expander("Install CortexLab"):
54
+ st.code("pip install -e ../cortexlab[analysis]", language="bash")
55
+
56
+ st.subheader("Display")
57
+ show_brain_3d = st.checkbox("Show 3D brain", value=True)
58
+ show_timeline = st.checkbox("Show cognitive load timeline", value=True)
59
+ timeline_window = st.slider("Timeline window (seconds)", 10, 120, 60)
60
+
61
+ # --- Initialize Engine ---
62
+ roi_indices, n_vertices = make_roi_indices()
63
+
64
+ if "live_engine" not in st.session_state:
65
+ st.session_state["live_engine"] = None
66
+ if "live_running" not in st.session_state:
67
+ st.session_state["live_running"] = False
68
+
69
+ # --- Controls ---
70
+ col_start, col_stop, col_status = st.columns([1, 1, 2])
71
+
72
+ with col_start:
73
+ start_clicked = st.button("▶ Start", type="primary", use_container_width=True,
74
+ disabled=st.session_state.get("live_running", False))
75
+
76
+ with col_stop:
77
+ stop_clicked = st.button("⬛ Stop", use_container_width=True,
78
+ disabled=not st.session_state.get("live_running", False))
79
+
80
+ # Handle Start
81
+ if start_clicked and deps_ok:
82
+ # Create capture source
83
+ if source_type == "webcam":
84
+ capture = WebcamCapture(fps=capture_fps)
85
+ elif source_type == "screen":
86
+ capture = ScreenCapture(fps=capture_fps)
87
+ elif source_type == "file":
88
+ if uploaded_file is not None:
89
+ import tempfile, os
90
+ tmp_path = os.path.join(tempfile.gettempdir(), uploaded_file.name)
91
+ with open(tmp_path, "wb") as f:
92
+ f.write(uploaded_file.read())
93
+ capture = FileStreamer(file_path=tmp_path, fps=capture_fps)
94
+ else:
95
+ st.error("Upload a video file first.")
96
+ st.stop()
97
+
98
+ # Create and start engine
99
+ engine = LiveInferenceEngine(
100
+ n_vertices=n_vertices,
101
+ roi_indices=roi_indices,
102
+ device=device,
103
+ )
104
+ engine.start(capture)
105
+ st.session_state["live_engine"] = engine
106
+ st.session_state["live_running"] = True
107
+ st.rerun()
108
+
109
+ # Handle Stop
110
+ if stop_clicked:
111
+ engine = st.session_state.get("live_engine")
112
+ if engine:
113
+ engine.stop()
114
+ st.session_state["live_running"] = False
115
+ st.rerun()
116
+
117
+ # --- Status Bar ---
118
+ with col_status:
119
+ engine = st.session_state.get("live_engine")
120
+ if engine and st.session_state.get("live_running"):
121
+ metrics = engine.get_metrics()
122
+ st.markdown(f"""
123
+ <div style="display: flex; gap: 1.5rem; align-items: center; padding: 0.5rem;">
124
+ <span style="color: #EF4444; font-size: 1.2rem;">● LIVE</span>
125
+ <span style="color: #94A3B8;">Mode: <b style="color: #06B6D4;">{metrics.mode}</b></span>
126
+ <span style="color: #94A3B8;">FPS: <b style="color: #10B981;">{metrics.fps:.1f}</b></span>
127
+ <span style="color: #94A3B8;">Predictions: <b style="color: #A29BFE;">{metrics.total_predictions}</b></span>
128
+ <span style="color: #94A3B8;">Latency: <b style="color: #FFEAA7;">{metrics.avg_latency_ms:.0f}ms</b></span>
129
+ </div>
130
+ """, unsafe_allow_html=True)
131
+ elif not st.session_state.get("live_running"):
132
+ st.markdown('<span style="color: #64748B;">Ready. Select a source and click Start.</span>', unsafe_allow_html=True)
133
+
134
+ st.divider()
135
+
136
+ # --- Live Display ---
137
+ if st.session_state.get("live_running") and engine:
138
+ predictions = engine.get_predictions(timeline_window)
139
+
140
+ if predictions:
141
+ latest = predictions[-1]
142
+
143
+ # --- Cognitive Load Metrics ---
144
+ cog = latest.cognitive_load
145
+ c1, c2, c3, c4, c5 = st.columns(5)
146
+ with c1: glow_card("Overall", f"{cog.get('Overall', 0):.2f}", "", "#7C3AED")
147
+ with c2: glow_card("Visual", f"{cog.get('Visual Complexity', 0):.2f}", "", "#00D2FF")
148
+ with c3: glow_card("Auditory", f"{cog.get('Auditory Demand', 0):.2f}", "", "#FF6B6B")
149
+ with c4: glow_card("Language", f"{cog.get('Language Processing', 0):.2f}", "", "#A29BFE")
150
+ with c5: glow_card("Executive", f"{cog.get('Executive Load', 0):.2f}", "", "#FFEAA7")
151
+
152
+ col_brain, col_timeline = st.columns([1, 1])
153
+
154
+ # --- 3D Brain ---
155
+ if show_brain_3d:
156
+ with col_brain:
157
+ section_header("Brain Activation", f"t = {latest.timestamp:.1f}s")
158
+ try:
159
+ from brain_mesh import (
160
+ load_fsaverage_mesh, render_interactive_3d,
161
+ )
162
+ coords, faces = load_fsaverage_mesh("left", "fsaverage4") # Fast mesh for live
163
+ n_mesh = coords.shape[0]
164
+
165
+ # Map vertex data to mesh size
166
+ vd = latest.vertex_data
167
+ if len(vd) < n_mesh:
168
+ vd = np.interp(np.linspace(0, len(vd) - 1, n_mesh), np.arange(len(vd)), vd)
169
+ elif len(vd) > n_mesh:
170
+ vd = vd[:n_mesh]
171
+
172
+ fig_brain = render_interactive_3d(
173
+ coords, faces, vd, cmap="Inferno", vmin=0, vmax=0.8,
174
+ bg_color="#050510", initial_view="Lateral Left",
175
+ )
176
+ if fig_brain:
177
+ fig_brain.update_layout(height=400, margin=dict(l=0, r=0, t=0, b=0))
178
+ st.plotly_chart(fig_brain, use_container_width=True)
179
+ except Exception as e:
180
+ st.warning(f"Brain render error: {e}")
181
+
182
+ # --- Cognitive Load Timeline ---
183
+ if show_timeline:
184
+ with col_timeline:
185
+ section_header("Cognitive Load Timeline", f"{len(predictions)} data points")
186
+
187
+ fig_tl = go.Figure()
188
+ timestamps = [p.timestamp for p in predictions]
189
+ dim_colors = {
190
+ "Visual Complexity": "#00D2FF",
191
+ "Auditory Demand": "#FF6B6B",
192
+ "Language Processing": "#A29BFE",
193
+ "Executive Load": "#FFEAA7",
194
+ }
195
+
196
+ for dim, color in dim_colors.items():
197
+ values = [p.cognitive_load.get(dim, 0) for p in predictions]
198
+ fig_tl.add_trace(go.Scatter(
199
+ x=timestamps, y=values, name=dim.split()[0],
200
+ line=dict(color=color, width=2), mode="lines",
201
+ ))
202
+
203
+ fig_tl.update_layout(
204
+ xaxis_title="Time (seconds)", yaxis_title="Load",
205
+ yaxis_range=[0, 1.05], height=400,
206
+ template="plotly_dark",
207
+ legend=dict(orientation="h", yanchor="bottom", y=1.02),
208
+ margin=dict(l=40, r=10, t=10, b=40),
209
+ )
210
+ st.plotly_chart(fig_tl, use_container_width=True)
211
+
212
+ # --- Store latest predictions for other pages ---
213
+ all_vertex_data = np.array([p.vertex_data for p in predictions])
214
+ st.session_state["brain_predictions"] = all_vertex_data
215
+ st.session_state["roi_indices"] = roi_indices
216
+ st.session_state["data_source"] = "live_inference"
217
+
218
+ # --- Navigation ---
219
+ st.divider()
220
+ st.markdown("**Explore live predictions in other tools:**")
221
+ c1, c2, c3, c4 = st.columns(4)
222
+ with c1: st.page_link("pages/5_Brain_Viewer.py", label="Brain Viewer", icon="🧠")
223
+ with c2: st.page_link("pages/2_Cognitive_Load.py", label="Cognitive Load", icon="📊")
224
+ with c3: st.page_link("pages/3_Temporal_Dynamics.py", label="Temporal Dynamics", icon="⏱️")
225
+ with c4: st.page_link("pages/4_Connectivity.py", label="Connectivity", icon="🔗")
226
+
227
+ # --- Auto-refresh ---
228
+ time.sleep(1.0)
229
+ st.rerun()
230
+
231
+ else:
232
+ # --- Not running: show instructions ---
233
+ st.markdown("""
234
+ <div style="
235
+ text-align: center; padding: 3rem 2rem;
236
+ background: rgba(15, 15, 40, 0.4);
237
+ border: 1px solid rgba(100, 100, 255, 0.15);
238
+ border-radius: 16px; margin: 1rem 0;
239
+ ">
240
+ <div style="font-size: 3rem; margin-bottom: 1rem;">🧠</div>
241
+ <h3 style="color: #F1F5F9; margin-bottom: 0.5rem;">Ready for Live Brain Prediction</h3>
242
+ <p style="color: #94A3B8; max-width: 600px; margin: 0 auto;">
243
+ Select a source (webcam, screen capture, or video file) from the sidebar,
244
+ then click <b>Start</b> to begin real-time brain activation prediction.
245
+ </p>
246
+ <div style="margin-top: 1.5rem; display: flex; justify-content: center; gap: 2rem;">
247
+ <div style="text-align: center;">
248
+ <div style="font-size: 1.5rem;">📹</div>
249
+ <div style="color: #06B6D4; font-size: 0.85rem; font-weight: 600;">Webcam</div>
250
+ <div style="color: #64748B; font-size: 0.75rem;">Live camera feed</div>
251
+ </div>
252
+ <div style="text-align: center;">
253
+ <div style="font-size: 1.5rem;">🖥️</div>
254
+ <div style="color: #7C3AED; font-size: 0.85rem; font-weight: 600;">Screen</div>
255
+ <div style="color: #64748B; font-size: 0.75rem;">Capture display</div>
256
+ </div>
257
+ <div style="text-align: center;">
258
+ <div style="font-size: 1.5rem;">🎬</div>
259
+ <div style="color: #EC4899; font-size: 0.85rem; font-weight: 600;">Video File</div>
260
+ <div style="color: #64748B; font-size: 0.75rem;">Frame-by-frame</div>
261
+ </div>
262
+ </div>
263
+ </div>
264
+ """, unsafe_allow_html=True)
265
+
266
+ # Show last predictions if available
267
+ if st.session_state.get("brain_predictions") is not None and st.session_state.get("data_source") == "live_inference":
268
+ st.info(f"Previous session predictions available ({st.session_state['brain_predictions'].shape[0]} timepoints). Navigate to analysis pages to explore them.")
269
+
270
+ # --- Methodology ---
271
+ with st.expander("About Live Inference", expanded=False):
272
+ st.markdown(f"""
273
+ **Mode: {'Real (CortexLab)' if CORTEXLAB_AVAILABLE else 'Simulation'}**
274
+
275
+ {'**Real Inference**: Uses TRIBE v2 to extract features (V-JEPA2, Wav2Vec-BERT, LLaMA 3.2) and predict fMRI brain activation at each captured frame. Requires GPU for interactive speed.' if CORTEXLAB_AVAILABLE else '**Simulation Mode**: CortexLab is not installed. Predictions are generated from image statistics (brightness, contrast, color variance) mapped to brain ROIs. This demonstrates the pipeline without requiring GPU or model weights.'}
276
+
277
+ **Sources:**
278
+ - **Webcam**: Captures frames via OpenCV. Requires `pip install opencv-python`.
279
+ - **Screen Capture**: Captures display via mss. Requires `pip install mss Pillow`.
280
+ - **Video File**: Reads uploaded video frame-by-frame at the specified FPS.
281
+
282
+ **Cognitive Load Dimensions** are computed from predicted vertex activations
283
+ grouped by HCP MMP1.0 ROIs (same method as the Cognitive Load Scorer page).
284
+
285
+ **Performance:**
286
+ - Simulation mode: ~1-5ms per frame (CPU)
287
+ - Real inference with GPU: ~50-200ms per frame
288
+ - Real inference with CPU: ~5-30s per frame (not recommended)
289
+
290
+ **To enable real inference:**
291
+ ```bash
292
+ pip install -e path/to/cortexlab[analysis]
293
+ ```
294
+ """)