Spaces:
Running
Running
| """Real-time brain prediction engine. | |
| Runs in a background thread, consuming frames from a capture source, | |
| extracting features, and producing brain predictions via TRIBE v2. | |
| When CortexLab is not installed, falls back to a simulation mode that | |
| generates synthetic predictions from frame statistics. | |
| """ | |
| from __future__ import annotations | |
| import time | |
| import threading | |
| import logging | |
| from collections import deque | |
| from dataclasses import dataclass, field | |
| import numpy as np | |
| from live_capture import BaseCapture, MediaFrame | |
logger = logging.getLogger(__name__)

# CortexLab is an optional dependency: probe for it once at import time.
# When it is missing, the engine falls back to simulation mode.
try:
    from cortexlab.inference.predictor import TribeModel
except ImportError:
    CORTEXLAB_AVAILABLE = False
else:
    CORTEXLAB_AVAILABLE = True
@dataclass
class LivePrediction:
    """A single prediction produced by the live engine, with metadata.

    Attributes:
        vertex_data: Per-vertex activation values, shape ``(n_vertices,)``.
        timestamp: Timestamp of the frame the prediction was computed from.
        cognitive_load: Cognitive-dimension scores keyed by dimension name
            (the engine also adds an ``"Overall"`` entry).
        processing_time_ms: Time spent producing this prediction, in
            milliseconds; filled in by the engine after inference.
    """

    vertex_data: np.ndarray  # (n_vertices,)
    timestamp: float
    # default_factory: each prediction gets its own dict (no shared default).
    cognitive_load: dict[str, float] = field(default_factory=dict)
    processing_time_ms: float = 0.0
@dataclass
class LiveMetrics:
    """Aggregated metrics reported by the live engine.

    Attributes:
        fps: Predictions per second over a recent window.
        total_frames: Latest frame counter reported by the capture source.
        total_predictions: Number of predictions produced in this run.
        avg_latency_ms: Mean per-prediction processing time over a recent
            window, in milliseconds.
        is_running: Whether the engine's inference loop is active.
        mode: ``"simulation"`` or ``"cortexlab"``.
    """

    fps: float = 0.0
    total_frames: int = 0
    total_predictions: int = 0
    avg_latency_ms: float = 0.0
    is_running: bool = False
    mode: str = "simulation"  # "simulation" or "cortexlab"
class LiveInferenceEngine:
    """Background engine for real-time brain prediction.

    Consumes frames from a capture source on a daemon thread and keeps a
    bounded buffer of the most recent predictions. If CortexLab is importable
    and the TRIBE v2 checkpoint loads, real inference is used; otherwise the
    engine runs in simulation mode, deriving synthetic activations from frame
    statistics.

    Args:
        n_vertices: Number of vertices generated in simulation mode.
        roi_indices: Mapping of ROI name -> vertex indices (array-like of
            ints), used by simulation mode and cognitive-load scoring.
        buffer_size: Maximum number of predictions retained.
        checkpoint: Checkpoint passed to ``TribeModel.from_pretrained``.
        device: Device string passed to ``TribeModel.from_pretrained``.
        cache_folder: Cache directory passed to ``TribeModel.from_pretrained``.
    """

    # ROI groups driven by different frame statistics in simulation mode.
    _VISUAL_ROIS = frozenset({"V1", "V2", "V3", "V4", "MT", "MST", "FFC", "VVC"})
    _AUDITORY_ROIS = frozenset({"A1", "LBelt", "MBelt", "PBelt", "A4", "A5"})
    _LANGUAGE_ROIS = frozenset({"44", "45", "IFJa", "IFJp", "TPOJ1", "TPOJ2"})
    _EXECUTIVE_ROIS = frozenset({"46", "9-46d", "8Av", "8Ad", "FEF"})

    # Number of recent predictions used for the rolling fps / latency metrics.
    _METRICS_WINDOW = 30

    def __init__(
        self,
        n_vertices: int = 580,
        roi_indices: dict | None = None,
        buffer_size: int = 120,
        checkpoint: str = "facebook/tribev2",
        device: str = "auto",
        cache_folder: str = "./cache",
    ):
        self.n_vertices = n_vertices
        self.roi_indices = roi_indices or {}
        self.buffer_size = buffer_size
        self.checkpoint = checkpoint
        self.device = device
        self.cache_folder = cache_folder
        self._predictions: deque[LivePrediction] = deque(maxlen=buffer_size)
        self._running = False
        self._thread: threading.Thread | None = None
        self._lock = threading.Lock()  # guards self._predictions
        self._model = None
        self._metrics = LiveMetrics()
        self._capture: BaseCapture | None = None

    def start(self, capture: BaseCapture):
        """Start the engine on ``capture``; no-op if already running.

        Loads the TRIBE v2 model (once) when CortexLab is available, falling
        back to simulation mode if loading fails, then starts the capture and
        the inference thread. If ``capture.start()`` raises, the engine is
        left stopped (so it can be started again) and the error propagates.
        """
        if self._running:
            return
        self._capture = capture
        self._running = True
        self._metrics = LiveMetrics(is_running=True)
        self._metrics.mode = self._ensure_model()

        try:
            capture.start()
        except Exception:
            # Without this reset the engine would report "running" forever
            # and every later start() call would be ignored.
            self._running = False
            self._metrics.is_running = False
            raise

        self._thread = threading.Thread(target=self._inference_loop, daemon=True)
        self._thread.start()

    def _ensure_model(self) -> str:
        """Load the TRIBE v2 model if needed; return the resulting mode name."""
        if not CORTEXLAB_AVAILABLE:
            return "simulation"
        if self._model is not None:
            # Loaded by a previous start(); reuse it instead of reloading.
            return "cortexlab"
        try:
            logger.info("Loading TRIBE v2 model...")
            self._model = TribeModel.from_pretrained(
                self.checkpoint, device=self.device, cache_folder=self.cache_folder
            )
        except Exception as e:
            logger.warning("Failed to load model: %s. Using simulation mode.", e)
            self._model = None
            return "simulation"
        logger.info("Model loaded. Using real inference.")
        return "cortexlab"

    def stop(self):
        """Stop the capture source and wait up to 5 s for the inference thread."""
        self._running = False
        if self._capture:
            self._capture.stop()
        if self._thread:
            self._thread.join(timeout=5.0)
            # Detach the thread: if the join timed out, the straggler will see
            # it is no longer the engine's thread and won't touch shared state.
            self._thread = None
        self._metrics.is_running = False

    def get_latest_prediction(self) -> LivePrediction | None:
        """Return the most recent prediction, or None if there is none yet."""
        with self._lock:
            return self._predictions[-1] if self._predictions else None

    def get_predictions(self, n: int = 60) -> list[LivePrediction]:
        """Return up to the ``n`` most recent predictions, oldest first.

        ``n <= 0`` returns an empty list (a bare ``[-0:]`` slice would
        return the entire buffer).
        """
        if n <= 0:
            return []
        with self._lock:
            return list(self._predictions)[-n:]

    def get_metrics(self) -> LiveMetrics:
        """Return the metrics object (updated in place by the inference thread)."""
        return self._metrics

    def _inference_loop(self):
        """Consume new frames from the capture and append predictions.

        Each distinct frame (detected via the capture's ``frame_count``) is
        processed once. Runs until ``stop()`` clears ``_running``, the capture
        stops, or an unexpected error occurs (which is logged).
        """
        capture = self._capture
        completion_times: deque[float] = deque(maxlen=self._METRICS_WINDOW)
        latencies_ms: deque[float] = deque(maxlen=self._METRICS_WINDOW)
        last_frame_count = 0
        try:
            while self._running:
                frame = capture.get_latest_frame()
                if frame is None:
                    time.sleep(0.1)
                    continue

                # Skip if we already processed this frame.
                current_count = capture.frame_count
                if current_count == last_frame_count:
                    # No new frame. If the capture has stopped (e.g. a file
                    # reached its end) none will ever arrive, so exit instead
                    # of polling an unchanging counter forever.
                    if not capture.is_running:
                        break
                    time.sleep(0.05)
                    continue
                last_frame_count = current_count

                # perf_counter: monotonic, high resolution, immune to clock jumps.
                started = time.perf_counter()
                if self._model is not None and self._metrics.mode == "cortexlab":
                    prediction = self._run_real_inference(frame)
                else:
                    prediction = self._run_simulation(frame)
                finished = time.perf_counter()
                elapsed_ms = (finished - started) * 1000
                prediction.processing_time_ms = elapsed_ms

                with self._lock:
                    self._predictions.append(prediction)

                # Update metrics over a rolling window.
                completion_times.append(finished)
                latencies_ms.append(elapsed_ms)
                self._metrics.total_predictions += 1
                self._metrics.total_frames = current_count
                self._metrics.avg_latency_ms = sum(latencies_ms) / len(latencies_ms)
                span = completion_times[-1] - completion_times[0]
                if span > 0:  # guard: one sample, or equal timestamps
                    self._metrics.fps = (len(completion_times) - 1) / span

                # Check if capture stopped (file ended).
                if not capture.is_running:
                    break
        except Exception:
            logger.exception("Live inference loop failed; stopping engine.")
        finally:
            # Only the engine's current thread may flip shared state: a
            # straggler from a previous run (whose join timed out) must not
            # stop a newer run started in the meantime.
            if self._thread is threading.current_thread():
                self._running = False
                self._metrics.is_running = False

    def _run_real_inference(self, frame: MediaFrame) -> LivePrediction:
        """Run TRIBE v2 inference on a single frame.

        The frame is written to a one-frame temporary MP4 and passed through
        the model's ``get_events_dataframe`` and ``predict``. Multi-row
        predictions are averaged, then min-max normalized to [0, 1]. Any
        failure (or a frame without video) falls back to simulation mode.
        """
        import contextlib
        import os
        import tempfile

        if frame.video_frame is None:
            return self._run_simulation(frame)

        tmp_path = None
        try:
            import cv2

            # Unique per-call file: a fixed name in the shared temp dir can
            # collide between engines and is open to symlink tricks.
            fd, tmp_path = tempfile.mkstemp(prefix="cortexlab_live_", suffix=".mp4")
            os.close(fd)

            h, w = frame.video_frame.shape[:2]
            writer = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*"mp4v"), 1, (w, h))
            try:
                # Frames are converted RGB -> BGR for OpenCV.
                writer.write(cv2.cvtColor(frame.video_frame, cv2.COLOR_RGB2BGR))
            finally:
                writer.release()

            events = self._model.get_events_dataframe(video_path=tmp_path)
            preds, _ = self._model.predict(events, verbose=False)
            vertex_data = preds.mean(axis=0) if preds.ndim == 2 else preds

            # Normalize to [0, 1]; a constant map carries no contrast -> zeros.
            vmin, vmax = vertex_data.min(), vertex_data.max()
            if vmax > vmin:
                vertex_data = (vertex_data - vmin) / (vmax - vmin)
            else:
                vertex_data = np.zeros_like(vertex_data, dtype=float)
        except Exception as e:
            logger.warning("Inference failed: %s. Falling back to simulation.", e)
            return self._run_simulation(frame)
        finally:
            # Always remove the temp file, including on failure.
            if tmp_path is not None:
                with contextlib.suppress(OSError):
                    os.unlink(tmp_path)

        cog_load = self._compute_cognitive_load(vertex_data)
        return LivePrediction(
            vertex_data=vertex_data,
            timestamp=frame.timestamp,
            cognitive_load=cog_load,
        )

    def _run_simulation(self, frame: MediaFrame) -> LivePrediction:
        """Generate synthetic predictions from frame statistics.

        Contrast, colour variance and brightness of the frame drive fixed ROI
        groups; other ROIs get a flat baseline. The RNG is seeded from the
        frame timestamp, so the same frame yields the same output. The result
        is clipped to [0, 1].
        """
        rng = np.random.default_rng(int(frame.timestamp * 1000) % (2**31))
        # Base noise
        vertex_data = rng.standard_normal(self.n_vertices) * 0.03

        if frame.video_frame is not None:
            # NOTE(review): /255 assumes 8-bit frames — confirm with capture.
            img = frame.video_frame.astype(np.float32) / 255.0
            brightness = img.mean()
            contrast = img.std()
            color_variance = img.var(axis=(0, 1)).mean()

            for roi_name, vertices in self.roi_indices.items():
                # asarray: accept plain lists of indices as well as arrays.
                idx = np.asarray(vertices)
                valid = idx[idx < self.n_vertices]
                if len(valid) == 0:
                    continue
                if roi_name in self._VISUAL_ROIS:
                    activation = contrast * 0.8 + color_variance * 0.5
                elif roi_name in self._AUDITORY_ROIS:
                    # No audio input here: low random baseline.
                    activation = 0.05 + rng.random() * 0.1
                elif roi_name in self._LANGUAGE_ROIS:
                    activation = brightness * 0.3
                elif roi_name in self._EXECUTIVE_ROIS:
                    activation = contrast * 0.5
                else:
                    activation = 0.1
                vertex_data[valid] = activation + rng.standard_normal(len(valid)) * 0.05

        vertex_data = np.clip(vertex_data, 0, 1)
        cog_load = self._compute_cognitive_load(vertex_data)
        return LivePrediction(
            vertex_data=vertex_data,
            timestamp=frame.timestamp,
            cognitive_load=cog_load,
        )

    def _compute_cognitive_load(self, vertex_data: np.ndarray) -> dict[str, float]:
        """Score each cognitive dimension from ROI activity.

        For every dimension in ``utils.COGNITIVE_DIMENSIONS`` (dimension ->
        ROI names), averages the mean absolute activation of each ROI found
        in ``roi_indices``, divides by the median absolute activation of all
        vertices, and caps at 1.0. Dimensions with no usable ROI score 0.0.
        ``"Overall"`` is the mean of the per-dimension scores.
        """
        from utils import COGNITIVE_DIMENSIONS

        n_verts = len(vertex_data)
        # Floor avoids division by zero when most vertices are exactly 0.
        baseline = max(float(np.median(np.abs(vertex_data))), 1e-8)
        scores: dict[str, float] = {}
        for dim, rois in COGNITIVE_DIMENSIONS.items():
            roi_means = []
            for roi in rois:
                if roi not in self.roi_indices:
                    continue
                idx = np.asarray(self.roi_indices[roi])
                valid = idx[idx < n_verts]
                if len(valid) > 0:
                    roi_means.append(np.abs(vertex_data[valid]).mean())
            scores[dim] = min(float(np.mean(roi_means)) / baseline, 1.0) if roi_means else 0.0
        scores["Overall"] = float(np.mean(list(scores.values()))) if scores else 0.0
        return scores