cortexlab-dashboard / live_capture.py
siddhant-rajhans
Add live brain prediction mode (webcam, screen capture, video file)
ab204cc
"""Media capture sources for live brain prediction.
Provides webcam, screen capture, and file streaming sources that
yield frames at a controlled rate for real-time inference.
"""
from __future__ import annotations
import time
import threading
import logging
from pathlib import Path
from collections import deque
from dataclasses import dataclass
import numpy as np
logger = logging.getLogger(__name__)
@dataclass
class MediaFrame:
"""A single frame from any media source."""
video_frame: np.ndarray | None = None # (H, W, 3) RGB
audio_chunk: np.ndarray | None = None # (samples,) float32
timestamp: float = 0.0
class BaseCapture:
"""Base class for media capture sources."""
def __init__(self, fps: float = 1.0):
self.fps = fps
self._running = False
self._buffer: deque[MediaFrame] = deque(maxlen=300)
self._thread: threading.Thread | None = None
self._lock = threading.Lock()
def start(self):
self._running = True
self._thread = threading.Thread(target=self._capture_loop, daemon=True)
self._thread.start()
def stop(self):
self._running = False
if self._thread:
self._thread.join(timeout=3.0)
def get_latest_frame(self) -> MediaFrame | None:
with self._lock:
return self._buffer[-1] if self._buffer else None
def get_all_frames(self) -> list[MediaFrame]:
with self._lock:
frames = list(self._buffer)
return frames
@property
def is_running(self) -> bool:
return self._running
@property
def frame_count(self) -> int:
return len(self._buffer)
def _capture_loop(self):
raise NotImplementedError
class WebcamCapture(BaseCapture):
"""Capture frames from webcam using OpenCV."""
def __init__(self, camera_index: int = 0, fps: float = 1.0, resolution: tuple = (640, 480)):
super().__init__(fps)
self.camera_index = camera_index
self.resolution = resolution
def _capture_loop(self):
try:
import cv2
except ImportError:
logger.error("OpenCV not installed. Run: pip install opencv-python")
self._running = False
return
cap = cv2.VideoCapture(self.camera_index)
if not cap.isOpened():
logger.error(f"Cannot open camera {self.camera_index}")
self._running = False
return
cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.resolution[0])
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.resolution[1])
start_time = time.time()
interval = 1.0 / self.fps
try:
while self._running:
ret, frame = cap.read()
if not ret:
break
# BGR -> RGB
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
media_frame = MediaFrame(
video_frame=frame_rgb,
timestamp=time.time() - start_time,
)
with self._lock:
self._buffer.append(media_frame)
time.sleep(interval)
finally:
cap.release()
class ScreenCapture(BaseCapture):
"""Capture screen frames using mss."""
def __init__(self, fps: float = 1.0, region: dict | None = None):
super().__init__(fps)
self.region = region # {"left": 0, "top": 0, "width": 1920, "height": 1080}
def _capture_loop(self):
try:
import mss
from PIL import Image
except ImportError:
logger.error("mss/PIL not installed. Run: pip install mss Pillow")
self._running = False
return
start_time = time.time()
interval = 1.0 / self.fps
with mss.mss() as sct:
monitor = self.region or sct.monitors[1] # Primary monitor
while self._running:
screenshot = sct.grab(monitor)
img = Image.frombytes("RGB", screenshot.size, screenshot.bgra, "raw", "BGRX")
frame = np.array(img)
media_frame = MediaFrame(
video_frame=frame,
timestamp=time.time() - start_time,
)
with self._lock:
self._buffer.append(media_frame)
time.sleep(interval)
class FileStreamer(BaseCapture):
"""Stream a video file frame-by-frame at real-time speed."""
def __init__(self, file_path: str, fps: float = 1.0):
super().__init__(fps)
self.file_path = file_path
def _capture_loop(self):
try:
import cv2
except ImportError:
logger.error("OpenCV not installed. Run: pip install opencv-python")
self._running = False
return
cap = cv2.VideoCapture(self.file_path)
if not cap.isOpened():
logger.error(f"Cannot open video: {self.file_path}")
self._running = False
return
video_fps = cap.get(cv2.CAP_PROP_FPS) or 30
# Skip frames to match our target FPS
frame_skip = max(1, int(video_fps / self.fps))
frame_idx = 0
start_time = time.time()
interval = 1.0 / self.fps
try:
while self._running:
ret, frame = cap.read()
if not ret:
self._running = False
break
frame_idx += 1
if frame_idx % frame_skip != 0:
continue
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
media_frame = MediaFrame(
video_frame=frame_rgb,
timestamp=time.time() - start_time,
)
with self._lock:
self._buffer.append(media_frame)
time.sleep(interval)
finally:
cap.release()
def get_capture_source(source_type: str, **kwargs) -> BaseCapture:
"""Factory function to create a capture source."""
sources = {
"webcam": WebcamCapture,
"screen": ScreenCapture,
"file": FileStreamer,
}
if source_type not in sources:
raise ValueError(f"Unknown source: {source_type}. Choose from {list(sources)}")
return sources[source_type](**kwargs)