# BEYOND / backend.py — uploaded by opsecsystems ("Upload 3 files", commit bc0830d, verified)
"""
EEG Mental Imagery Classification Backend
Full-stack server: OpenBCI Cyton+Daisy (16ch), LSL markers, neurofeedback, websocket API
"""
import asyncio
import json
import logging
import math
import os
import queue
import threading
import time
from collections import deque
from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
import uvicorn
# ─── Logging ─────────────────────────────────────────────────────────────────
# Configured BEFORE the optional imports below: calling logging.warning() on an
# unconfigured root logger installs a default WARNING-level handler, after which
# logging.basicConfig() is a silent no-op and the INFO level / format requested
# here would never take effect.
logging.basicConfig(level=logging.INFO, format="%(asctime)s │ %(levelname)s │ %(message)s")
logger = logging.getLogger("EEG-Backend")

# ─── Optional heavy imports (graceful degradation) ───────────────────────────
try:
    from brainflow.board_shim import BoardShim, BrainFlowInputParams, BoardIds, BrainFlowError
    from brainflow.data_filter import DataFilter, FilterTypes, DetrendOperations, WindowOperations
    BRAINFLOW_AVAILABLE = True
except ImportError:
    BRAINFLOW_AVAILABLE = False
    logger.warning("BrainFlow not installed – running in SIMULATION mode")

try:
    from pylsl import StreamInfo, StreamOutlet, StreamInlet, resolve_streams, cf_string, cf_float32
    LSL_AVAILABLE = True
except ImportError:
    LSL_AVAILABLE = False
    logger.warning("pylsl not installed – LSL streaming disabled")
except RuntimeError as exc:
    # pylsl raises RuntimeError (not ImportError) when the package is installed but
    # its native liblsl library cannot be loaded; degrade instead of crashing.
    LSL_AVAILABLE = False
    logger.warning(f"pylsl found but liblsl failed to load ({exc}) – LSL streaming disabled")

try:
    import scipy.signal as signal
    from scipy.signal import welch, butter, filtfilt
    SCIPY_AVAILABLE = True
except ImportError:
    SCIPY_AVAILABLE = False
    # SignalProcessor designs its filters with butter/filtfilt/welch, so signal
    # processing cannot work without scipy.
    logger.warning("scipy not installed – signal processing unavailable")

try:
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    from sklearn.preprocessing import StandardScaler
    from sklearn.pipeline import Pipeline
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    logger.warning("scikit-learn not installed – LDA classifier disabled")
# ─── Constants ────────────────────────────────────────────────────────────────
# Acquisition. NOTE(review): 250 Hz is the nominal rate used for simulation and as
# a default elsewhere; on hardware BoardManager reads the real rate from BrainFlow
# at connect time (BrainFlow documents 125 Hz for Cyton+Daisy) — confirm which
# consumers still rely on this nominal value.
SAMPLE_RATE = 250            # Hz (nominal)
N_CHANNELS = 16
EPOCH_DURATION = 4.0         # s, length of the analysed imagery window
BASELINE_DURATION = 1.0      # s, pre-stimulus baseline
BUFFER_SECONDS = 30          # s of history kept in the acquisition ring buffer

# Frequency bands as (low, high) edges in Hz.
THETA_BAND = (4, 8)
ALPHA_BAND = (8, 13)
BETA_BAND = (13, 30)
GAMMA_BAND = (30, 45)

# 16-channel parieto-occipital montage using 10-10 system labels. Values are 2-D
# (x, y) display coordinates: negative x = left hemisphere (odd-numbered sites),
# larger y = more anterior.
ELECTRODE_POSITIONS_16CH = {
    # parietal
    "P3": (-0.30, 0.50), "Pz": (0.00, 0.58), "P4": (0.30, 0.50),
    "P7": (-0.55, 0.42), "P8": (0.55, 0.42),
    # parieto-occipital
    "PO3": (-0.22, 0.38), "POz": (0.00, 0.38), "PO4": (0.22, 0.38),
    "PO7": (-0.40, 0.28), "PO8": (0.40, 0.28),
    # occipital
    "O1": (-0.20, 0.18), "Oz": (0.00, 0.18), "O2": (0.20, 0.18),
    # centro-parietal
    "CP3": (-0.25, 0.68), "CPz": (0.00, 0.72), "CP4": (0.25, 0.68),
}
# Montage order; also used as the channel order when labelling data rows.
CHANNEL_NAMES = [*ELECTRODE_POSITIONS_16CH]
# ─── Data Structures ──────────────────────────────────────────────────────────
class SessionPhase(str, Enum):
    """Coarse stage of the experiment.

    Mixing in ``str`` makes every member compare equal to its own name, so the
    value can be sent to the UI directly.
    """

    IDLE = "IDLE"                  # nothing running
    BASELINE = "BASELINE"
    ACQUISITION = "ACQUISITION"    # set when a trial starts
    FEEDBACK = "FEEDBACK"
    CONVERGENCE = "CONVERGENCE"
    LIBRARY = "LIBRARY"
@dataclass
class Trial:
    """A single mental-imagery trial and the data derived from it."""

    trial_id: int
    class_label: str
    onset_time: float                          # seconds since the epoch (start_trial passes time.time())
    quality_score: Optional[int] = None        # subject self-report 1-5; None while unrated
    eeg_epoch: Optional[np.ndarray] = None     # (n_channels, n_samples)
    features: Optional[np.ndarray] = None      # feature vector extracted from eeg_epoch
    nf_score: Optional[float] = None           # neurofeedback similarity in [0, 1]
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> dict:
        """JSON-friendly summary; the numpy payloads (eeg_epoch, features) are left out."""
        exported = ("trial_id", "class_label", "onset_time", "quality_score", "nf_score", "timestamp")
        return {name: getattr(self, name) for name in exported}
@dataclass
class NeuralState:
    """Summary of a class's "stable" neural state: a cluster centroid in feature space."""

    class_label: str
    centroid: np.ndarray          # mean feature vector of the contributing trials
    covariance: np.ndarray        # feature covariance (NeurofeedbackEngine adds a small ridge)
    n_trials: int                 # number of trials pooled into the state
    convergence_score: float      # cluster tightness in [0, 1]; higher means tighter
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
@dataclass
class SessionConfig:
    """Mutable settings for one recording session."""

    # Identity; the session id defaults to the local time at creation (YYYYmmdd_HHMMSS).
    subject_id: str = "S001"
    session_id: str = field(default_factory=lambda: datetime.now().strftime("%Y%m%d_%H%M%S"))
    # The two imagery classes being trained.
    class_a: str = "Class_A"
    class_b: str = "Class_B"
    n_trials_per_class: int = 20
    # Only trials rated >= this value (1-5 scale) feed the neural-state model and classifier.
    min_quality_for_model: int = 4
    # Similarity regarded as "converged". NOTE(review): not read by the code visible in this file.
    feedback_threshold: float = 0.65
    # Hardware connection; BrainFlow board id 2 = Cyton+Daisy.
    port: str = "/dev/ttyUSB0"
    board_id: int = 2
    # Falls back to simulation when BrainFlow could not be imported.
    simulate: bool = not BRAINFLOW_AVAILABLE
# ─── Signal Processing ────────────────────────────────────────────────────────
class SignalProcessor:
    """Filtering, spectral estimation and feature extraction for EEG epochs.

    Every method takes epochs shaped ``(n_channels, n_samples)``. Amplitude units
    are whatever the caller supplies, except ``artifact_rejection``, which
    assumes Volts.
    """

    def __init__(self, srate: int = SAMPLE_RATE, n_channels: int = N_CHANNELS,
                 mains_hz: float = 60.0):
        """
        Args:
            srate: sampling rate in Hz used to design every filter.
            n_channels: expected channel count (stored for callers; not enforced).
            mains_hz: power-line frequency notched out by ``preprocess``
                (60 by default; pass 50 for 50-Hz mains regions).

        Raises:
            ImportError: if scipy is unavailable (the filters cannot be designed).
        """
        if not SCIPY_AVAILABLE:
            raise ImportError("SignalProcessor requires scipy (pip install scipy)")
        self.srate = srate
        self.n_channels = n_channels
        self.mains_hz = mains_hz
        self._design_filters()

    def _design_filters(self):
        """Pre-compute 4th-order Butterworth band-pass and mains band-stop filters as (b, a)."""
        nyq = self.srate / 2
        self.bp_coefs = {}
        for name, (lo, hi) in [("broad", (1, 45)), ("alpha", ALPHA_BAND),
                               ("beta", BETA_BAND), ("theta", THETA_BAND)]:
            self.bp_coefs[name] = butter(4, [lo / nyq, hi / nyq], btype="band")
        # notch_60 / notch_50 are kept for backward compatibility; preprocess uses
        # notch_mains. Each is None when its stop band does not fit below Nyquist.
        self.notch_60 = self._design_notch(60.0)
        self.notch_50 = self._design_notch(50.0)
        self.notch_mains = self._design_notch(self.mains_hz)

    def _design_notch(self, center_hz: float, half_width_hz: float = 2.0):
        """Return (b, a) of a 4th-order band-stop around ``center_hz``, or None if it exceeds Nyquist."""
        nyq = self.srate / 2
        lo, hi = center_hz - half_width_hz, center_hz + half_width_hz
        if lo <= 0 or hi >= nyq:
            return None
        return butter(4, [lo / nyq, hi / nyq], btype="bandstop")

    def preprocess(self, epoch: np.ndarray) -> np.ndarray:
        """Band-pass (1-45 Hz), mains notch and per-channel DC removal.

        Args:
            epoch: array shaped (n_channels, n_samples); it is not modified.

        Returns:
            A new float array with the same shape.
        """
        b, a = self.bp_coefs["broad"]
        # Zero-phase filtering along the sample axis handles all channels in one call.
        out = filtfilt(b, a, np.asarray(epoch, dtype=float), axis=1)
        if self.notch_mains is not None:
            bn, an = self.notch_mains
            out = filtfilt(bn, an, out, axis=1)
        out -= out.mean(axis=1, keepdims=True)  # remove residual DC per channel
        return out

    def compute_psd(self, epoch: np.ndarray, fmax: float = 50.0) -> Tuple[np.ndarray, np.ndarray]:
        """Welch PSD per channel. Returns (freqs, psd), psd shaped (n_ch, n_freqs), truncated to ``fmax``."""
        nperseg = min(self.srate, epoch.shape[1])  # ~1 s segments (1 Hz resolution) when available
        freqs, psd = welch(epoch, fs=self.srate, nperseg=nperseg, axis=1)
        mask = freqs <= fmax
        return freqs[mask], psd[:, mask]

    def band_power(self, epoch: np.ndarray) -> Dict[str, np.ndarray]:
        """Integrated PSD per channel for the alpha/beta/theta/gamma bands. Returns dict of (n_ch,) arrays."""
        freqs, psd = self.compute_psd(epoch)
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
        integrate = getattr(np, "trapezoid", None) or np.trapz
        powers = {}
        for name, (lo, hi) in [("alpha", ALPHA_BAND), ("beta", BETA_BAND),
                               ("theta", THETA_BAND), ("gamma", GAMMA_BAND)]:
            idx = (freqs >= lo) & (freqs <= hi)
            powers[name] = integrate(psd[:, idx], freqs[idx], axis=1)
        return powers

    def extract_features(self, epoch: np.ndarray) -> np.ndarray:
        """Build the 1-D feature vector for one epoch.

        Concatenates, per channel: log alpha, log beta, log theta power,
        beta/alpha and theta/alpha ratios, and log variance of the cleaned signal
        (6 x n_channels values).
        """
        clean = self.preprocess(epoch)
        bp = self.band_power(clean)
        eps = 1e-10  # keeps logs/ratios finite for flat channels
        alpha = bp["alpha"] + eps
        beta = bp["beta"] + eps
        theta = bp["theta"] + eps
        log_var = np.log(clean.var(axis=1, ddof=1) + eps)  # == log(diag(np.cov(clean)))
        return np.concatenate([
            np.log(alpha), np.log(beta), np.log(theta),
            beta / alpha, theta / alpha,
            log_var,
        ]).astype(float)

    def compute_topomap(self, epoch: np.ndarray) -> Dict[str, float]:
        """Return per-channel alpha power (after preprocessing) keyed by channel name."""
        bp = self.band_power(self.preprocess(epoch))
        return {name: float(val) for name, val in zip(CHANNEL_NAMES, bp["alpha"])}

    def artifact_rejection(self, epoch: np.ndarray, threshold_uv: float = 100.0) -> bool:
        """Return True if every channel's peak-to-peak amplitude is below ``threshold_uv``.

        ``epoch`` must be in Volts; the threshold is given in microvolts.
        """
        peak_to_peak = epoch.max(axis=1) - epoch.min(axis=1)
        return bool(np.all(peak_to_peak < threshold_uv * 1e-6))
# ─── Neurofeedback Engine ─────────────────────────────────────────────────────
class NeurofeedbackEngine:
    """
    Computes similarity between current epoch features and stored neural state centroids.
    Score ∈ [0,1] where 1 = perfectly matching the target state.
    """

    def __init__(self):
        self.states: Dict[str, NeuralState] = {}
        # label -> (state the inverse was computed for, pinv(state.covariance)).
        # Avoids an SVD of the full feature covariance on every query; entries are
        # validated by identity, so a replaced state is re-inverted.
        self._inv_cov_cache: Dict[str, Tuple[Any, np.ndarray]] = {}

    def update_state(self, class_label: str, features_list: List[np.ndarray]):
        """Build/replace the stable state of ``class_label`` from high-quality feature vectors.

        Does nothing when fewer than 3 vectors are given.
        """
        if len(features_list) < 3:
            return
        mat = np.stack(features_list)
        centroid = mat.mean(axis=0)
        # Small ridge: the sample covariance is singular whenever there are fewer
        # trials than features.
        cov = np.cov(mat.T) + np.eye(mat.shape[1]) * 1e-6
        # Heuristic tightness: 1 - mean(dist) / (std(dist) + 1), floored at 0.
        dists = np.linalg.norm(mat - centroid, axis=1)
        convergence = max(0.0, 1.0 - np.mean(dists) / (np.std(dists) + 1.0))
        self.states[class_label] = NeuralState(
            class_label=class_label,
            centroid=centroid,
            covariance=cov,
            n_trials=len(features_list),
            convergence_score=float(convergence),
        )
        logger.info(f"Neural state updated for {class_label}: convergence={convergence:.3f}")

    def _inverse_covariance(self, class_label: str, state) -> np.ndarray:
        """Return pinv(state.covariance), computed once per state object."""
        cached = self._inv_cov_cache.get(class_label)
        if cached is not None and cached[0] is state:
            return cached[1]
        inv_cov = np.linalg.pinv(state.covariance)
        self._inv_cov_cache[class_label] = (state, inv_cov)
        return inv_cov

    def compute_similarity(self, features: np.ndarray, class_label: str) -> float:
        """Mahalanobis-based similarity of ``features`` to the state of ``class_label``.

        Returns a score in [0, 1]: 1.0 at the centroid, 0.5 at Mahalanobis
        distance 3, decaying as 1 / (1 + d/3). Returns 0.0 when no state exists
        for the label or the distance cannot be computed (e.g. NaN features).
        """
        state = self.states.get(class_label)
        if state is None:
            return 0.0
        try:
            diff = np.asarray(features, dtype=float) - state.centroid
            q = float(diff @ self._inverse_covariance(class_label, state) @ diff)
        except Exception as exc:  # best effort: never break the feedback loop
            logger.warning(f"Similarity computation failed for {class_label}: {exc}")
            return 0.0
        if not np.isfinite(q):
            return 0.0
        # Clamp tiny negative values caused by round-off in the pseudo-inverse.
        dist = np.sqrt(max(q, 0.0))
        score = 1.0 / (1.0 + dist / 3.0)
        return float(np.clip(score, 0.0, 1.0))

    def get_feedback_audio_params(self, score: float, positive_above: float = 0.65,
                                  neutral_above: float = 0.35) -> dict:
        """Map an NF score to audio synthesis params.

        Returns ``freq_hz`` in 200-800 Hz, ``volume`` in 0.3-1.0 (both linear in
        ``score``) and a ``tone`` label: "positive" above ``positive_above``,
        "neutral" above ``neutral_above``, otherwise "negative".
        """
        freq_hz = 200 + score * 600
        volume = 0.3 + score * 0.7
        if score > positive_above:
            tone = "positive"
        elif score > neutral_above:
            tone = "neutral"
        else:
            tone = "negative"
        return {"freq_hz": round(freq_hz, 1), "volume": round(volume, 3), "tone": tone}
# ─── LDA Classifier ───────────────────────────────────────────────────────────
class EEGClassifier:
    """Standardise-then-LDA classifier over epoch feature vectors.

    Without scikit-learn no model is built: ``fit`` always returns False and
    ``predict_proba`` returns an empty dict.
    """

    def __init__(self):
        if SKLEARN_AVAILABLE:
            steps = [
                ("scaler", StandardScaler()),
                ("lda", LinearDiscriminantAnalysis(solver="svd")),
            ]
            self.model = Pipeline(steps)
        self.is_trained = False
        self.classes_: List[str] = []

    def fit(self, X: np.ndarray, y: List[str]):
        """Train on feature matrix ``X`` (one row per epoch) and labels ``y``.

        Returns True when a model was trained, False if scikit-learn is missing
        or fewer than 6 epochs were supplied.
        """
        if not SKLEARN_AVAILABLE:
            return False
        if len(X) < 6:
            return False
        self.model.fit(X, y)
        self.is_trained = True
        # np.unique yields sorted labels, the same order as predict_proba's columns.
        self.classes_ = list(np.unique(y))
        logger.info(f"LDA trained on {len(X)} epochs, classes={self.classes_}")
        return True

    def predict_proba(self, features: np.ndarray) -> Dict[str, float]:
        """Class-probability dict for one feature vector, or ``{}`` before training."""
        if not self.is_trained:
            return {}
        row = features.reshape(1, -1)
        probabilities = self.model.predict_proba(row)[0]
        return dict(zip(self.classes_, map(float, probabilities)))
# ─── LSL Manager ──────────────────────────────────────────────────────────────
class LSLManager:
    """Publishes experiment markers (and optionally EEG samples) over Lab Streaming Layer.

    When pylsl is unavailable both outlets stay None and every push is a no-op.
    """

    def __init__(self, stream_name: str = "EEGMarkers"):
        self.outlet = None
        self.eeg_outlet = None
        if not LSL_AVAILABLE:
            return
        self._init_marker_stream(stream_name)
        self._init_eeg_stream()

    def _init_marker_stream(self, name: str):
        """Open a single-channel string outlet (nominal rate 0 = irregular) for event markers."""
        info = StreamInfo(name, "Markers", 1, 0, cf_string, f"markers-{name}")
        info.desc().append_child_value("manufacturer", "EEG-MI-Backend")
        self.outlet = StreamOutlet(info)
        logger.info(f"LSL marker stream '{name}' opened")

    def _init_eeg_stream(self):
        """Open a float32 EEG outlet with per-channel label/unit/type metadata."""
        info = StreamInfo("EEG-MI", "EEG", N_CHANNELS, SAMPLE_RATE, cf_float32, "eeg-mi")
        channels = info.desc().append_child("channels")
        for label in CHANNEL_NAMES:
            node = channels.append_child("channel")
            # NOTE(review): unit is declared as microvolts, while other parts of this
            # backend treat samples as Volts — confirm before pushing EEG here.
            for key, value in (("label", label), ("unit", "microvolts"), ("type", "EEG")):
                node.append_child_value(key, value)
        # Positional pylsl args: chunk_size=32, max_buffered=360.
        self.eeg_outlet = StreamOutlet(info, 32, 360)
        logger.info("LSL EEG stream opened")

    def push_marker(self, marker: str):
        """Send one string marker; skipped silently when no marker outlet is open."""
        if not self.outlet:
            return
        self.outlet.push_sample([marker])
        logger.debug(f"LSL marker: {marker}")

    def push_eeg_chunk(self, chunk: np.ndarray):
        """Push a chunk shaped (n_samples, n_channels); empty chunks are ignored."""
        if not self.eeg_outlet or chunk.size == 0:
            return
        self.eeg_outlet.push_chunk(chunk.tolist())
# ─── Board Manager ────────────────────────────────────────────────────────────
class BoardManager:
    """Handles Cyton+Daisy board or simulation.

    A daemon thread polls the data source every ~50 ms and appends one
    ``(n_channels,)`` sample vector per time point to a ring buffer holding
    ``BUFFER_SECONDS`` of data; ``get_epoch`` reads the most recent window.

    Units: samples are stored in Volts. The simulator generates Volts, and
    BrainFlow's EEG output (microvolts) is rescaled on read. Downstream code
    therefore sees one consistent unit (artifact threshold, µV display scaling).
    """

    def __init__(self, config: SessionConfig):
        self.config = config
        self.board = None
        self.board_id = config.board_id
        self.srate = SAMPLE_RATE  # replaced by the board-reported rate on hardware connect
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._stop_event: Optional[threading.Event] = None  # per-thread stop signal
        self._buffer = deque(maxlen=BUFFER_SECONDS * SAMPLE_RATE)
        self._lock = threading.Lock()  # guards _buffer and its replacement
        self.connected = False
        # Simulation state
        self._sim_t = 0.0
        self._sim_phase = 0.0

    # ── Connection ─────────────────────────────────────────────────────────
    def connect(self) -> bool:
        """Open the data source; returns True on success.

        Any existing connection is torn down first. Calling this twice therefore
        neither re-opens a busy serial port nor leaves two readers running.
        """
        if self.connected:
            self.disconnect()
        if self.config.simulate:
            logger.info("Board: SIMULATION mode active")
            self._set_sample_rate(SAMPLE_RATE)
            self.connected = True
            return True
        if not BRAINFLOW_AVAILABLE:
            logger.error("BrainFlow not available; cannot connect to hardware")
            return False
        params = BrainFlowInputParams()
        params.serial_port = self.config.port
        board = None
        try:
            BoardShim.enable_dev_board_logger()
            board = BoardShim(self.board_id, params)
            board.prepare_session()
            board.start_stream(45000)
            srate = BoardShim.get_sampling_rate(self.board_id)
        except BrainFlowError as e:
            logger.error(f"Board connection failed: {e}")
            if board is not None:
                # Free the serial port if prepare_session() had already succeeded.
                try:
                    board.release_session()
                except Exception as exc:
                    logger.debug(f"release_session after failed connect: {exc}")
            return False
        self.board = board
        self._set_sample_rate(srate)
        self.connected = True
        logger.info(f"Cyton+Daisy connected: {self.config.port} @ {self.srate} Hz")
        return True

    def disconnect(self):
        """Stop acquisition, then stop and release the hardware session (best effort)."""
        # Stop the reader first so it cannot call get_board_data() on a released session.
        self.stop_acquisition()
        board, self.board = self.board, None
        if board is not None and BRAINFLOW_AVAILABLE:
            # Separate attempts: a failed stop_stream must not skip release_session.
            try:
                board.stop_stream()
            except Exception as exc:
                logger.debug(f"stop_stream failed: {exc}")
            try:
                board.release_session()
            except Exception as exc:
                logger.warning(f"release_session failed: {exc}")
        self.connected = False

    def _set_sample_rate(self, srate: float):
        """Adopt ``srate`` and size the ring buffer to BUFFER_SECONDS of it.

        Buffered samples are discarded when the rate changes, because samples at
        two different rates cannot be mixed in one epoch.
        """
        srate = int(srate)
        with self._lock:
            if srate != self.srate:
                self.srate = srate
                self._buffer = deque(maxlen=BUFFER_SECONDS * srate)

    # ── Acquisition thread ─────────────────────────────────────────────────
    def start_acquisition(self):
        """Start the background reader thread (no-op if one is already running)."""
        if self._running and self._thread is not None and self._thread.is_alive():
            return
        stop = threading.Event()
        self._stop_event = stop
        self._running = True
        self._thread = threading.Thread(target=self._acquisition_loop, args=(stop,), daemon=True)
        self._thread.start()

    def stop_acquisition(self, timeout: float = 1.0):
        """Signal the reader thread to stop and wait up to ``timeout`` seconds for it."""
        self._running = False
        if self._stop_event is not None:
            self._stop_event.set()
        thread = self._thread
        if thread is not None and thread is not threading.current_thread() and thread.is_alive():
            thread.join(timeout)

    def _acquisition_loop(self, stop: Optional[threading.Event] = None):
        """Poll the source every ~50 ms and append samples to the ring buffer.

        In simulation, the number of generated samples follows elapsed wall-clock
        time, so the simulated stream runs at ``self.srate``. It no longer runs at
        a fixed 50 samples per poll (4x real time at 250 Hz).
        """
        last = time.monotonic()
        owed = 0.0  # fractional simulated samples carried to the next poll
        while self._running and not (stop is not None and stop.is_set()):
            now = time.monotonic()
            if self.config.simulate:
                owed += (now - last) * self.srate
                n_new = int(owed)
                owed -= n_new
                chunk = self._generate_sim_chunk(n_new) if n_new > 0 else None
            else:
                chunk = self._read_board_chunk()
            last = now
            if chunk is not None and chunk.shape[1] > 0:
                with self._lock:
                    # Columns are time points; store each as an (n_channels,) vector.
                    self._buffer.extend(chunk.T)
            time.sleep(0.05)

    def _read_board_chunk(self) -> Optional[np.ndarray]:
        """Drain BrainFlow's buffer; returns (16, n_samples) in Volts, or None if empty/failed."""
        try:
            data = self.board.get_board_data()
            if data.shape[1] == 0:
                return None
            eeg_channels = BoardShim.get_eeg_channels(self.board_id)
            # BrainFlow reports EEG in microvolts; this backend works in Volts.
            return data[eeg_channels[:N_CHANNELS], :] * 1e-6
        except Exception as e:
            logger.warning(f"Board read error: {e}")
            return None

    def _generate_sim_chunk(self, n_samples: int) -> np.ndarray:
        """Realistic EEG simulation in Volts: alpha rhythm + beta + theta + noise + ocular artifacts."""
        srate = self.srate
        # Contiguous sample times: sample k of this chunk is at _sim_t + k/srate.
        t = self._sim_t + np.arange(n_samples) / srate
        self._sim_t += n_samples / srate
        chunk = np.zeros((N_CHANNELS, n_samples))
        for ch in range(N_CHANNELS):
            # Alpha (10 Hz) with a channel-dependent amplitude
            alpha_amp = (0.5 + 0.5 * math.sin(ch * 0.4)) * 15e-6
            alpha = alpha_amp * np.sin(2 * np.pi * 10 * t + ch * 0.3)
            # Beta (20 Hz)
            beta = 5e-6 * np.sin(2 * np.pi * 20 * t + ch * 0.6)
            # Theta (6 Hz)
            theta = 8e-6 * np.sin(2 * np.pi * 6 * t)
            # White noise
            noise = np.random.randn(n_samples) * 3e-6
            chunk[ch] = alpha + beta + theta + noise
        # Occasional blink-like artifact (150 µV peak) on channels 0-1
        if np.random.rand() < 0.02:
            idx = np.random.randint(0, max(1, n_samples - 20))
            pulse = np.hanning(20) * 150e-6
            end = min(idx + 20, n_samples)
            chunk[0, idx:end] += pulse[:end - idx]
            chunk[1, idx:end] += pulse[:end - idx] * 0.5
        return chunk

    # ── Read access ────────────────────────────────────────────────────────
    def get_epoch(self, duration: float, offset: float = 0.0) -> Optional[np.ndarray]:
        """Return the most recent ``duration`` seconds as (n_channels, n_samples).

        ``offset`` extra seconds must also be buffered. They are the OLDEST part
        of the read window and are dropped, so the result always ends at the
        newest sample. Returns None until enough data is buffered, or if the
        request covers no samples.
        """
        with self._lock:
            n_wanted = int((duration + offset) * self.srate)
            n_offset = int(offset * self.srate)
            if n_wanted <= 0 or len(self._buffer) < n_wanted:
                return None
            tail = list(self._buffer)[-n_wanted:]
        arr = np.array(tail).T  # (n_ch, n_samples)
        return arr[:, n_offset:]

    def get_psd_snapshot(self) -> Optional[np.ndarray]:
        """Latest 2-second window for live display (n_ch, n_samples)."""
        return self.get_epoch(2.0)
# ─── Session Manager ──────────────────────────────────────────────────────────
class SessionManager:
    """Coordinates one recording session.

    Owns the board, signal processor, neurofeedback engine, classifier, LSL
    outlets and connected WebSocket clients, and keeps the list of trials.
    """

    # Seconds after trial onset that are skipped before the analysed EEG window starts.
    _EPOCH_SKIP_S = 0.5

    def __init__(self):
        self.config = SessionConfig()
        self.phase = SessionPhase.IDLE
        self.trials: List[Trial] = []
        self.trial_counter = 0
        self.current_class: Optional[str] = None
        self.target_class: Optional[str] = None
        self.processor = SignalProcessor()
        self.nf_engine = NeurofeedbackEngine()
        self.classifier = EEGClassifier()
        self.lsl = LSLManager()
        self.board = BoardManager(self.config)
        self._ws_clients: List[WebSocket] = []
        self._event_queue: asyncio.Queue = asyncio.Queue()
        self._current_loop: Optional[asyncio.AbstractEventLoop] = None

    # ── Connection Management ──────────────────────────────────────────────
    def connect_board(self) -> dict:
        """Connect the board, start acquisition and emit SESSION_START on success."""
        ok = self.board.connect()
        if ok:
            # Filters are designed for one sampling rate; hardware may report a
            # rate other than the nominal one the processor was built with.
            if self.board.srate != self.processor.srate:
                logger.info(f"Re-designing filters for {self.board.srate} Hz")
                self.processor = SignalProcessor(srate=self.board.srate)
            self.board.start_acquisition()
            self.lsl.push_marker("SESSION_START")
        return {"status": "connected" if ok else "error", "simulate": self.config.simulate}

    def disconnect_board(self):
        """Close the board and emit SESSION_END."""
        self.board.disconnect()
        self.lsl.push_marker("SESSION_END")

    # ── Trial Management ───────────────────────────────────────────────────
    def start_trial(self, class_label: str) -> Trial:
        """Open a new trial for ``class_label`` and emit a TRIAL_START marker."""
        self.trial_counter += 1
        trial = Trial(
            trial_id=self.trial_counter,
            class_label=class_label,
            onset_time=time.time(),
        )
        self.trials.append(trial)
        self.current_class = class_label
        self.phase = SessionPhase.ACQUISITION
        self.lsl.push_marker(f"TRIAL_START;class={class_label};id={self.trial_counter}")
        logger.info(f"Trial {self.trial_counter} started: {class_label}")
        return trial

    def end_trial(self, quality: int) -> dict:
        """Close the most recent unrated trial with the subject's 1-5 rating.

        The method performs these steps:
        - extracts the trial's EEG window and rejects it on artifacts;
        - rebuilds the neural-state model and classifier from high-quality trials;
        - scores the epoch against its class's neural state.

        Returns:
            Result dict for the UI, or ``{"error": ...}`` when there is no active
            trial or ``quality`` is not an integer between 1 and 5.
        """
        active = self._get_active_trial()
        if active is None:
            return {"error": "no active trial"}
        try:
            quality = int(quality)
        except (TypeError, ValueError):
            return {"error": f"invalid quality: {quality!r}"}
        if not 1 <= quality <= 5:
            return {"error": f"quality must be between 1 and 5, got {quality}"}
        active.quality_score = quality
        self.lsl.push_marker(f"TRIAL_END;id={active.trial_id};quality={quality}")

        epoch = self._extract_trial_epoch(active)
        if epoch is None:
            logger.warning(f"Trial {active.trial_id}: not enough buffered EEG – no epoch extracted")
        elif self.processor.artifact_rejection(epoch):
            active.eeg_epoch = epoch
            active.features = self.processor.extract_features(epoch)
            logger.info(f"Trial {active.trial_id}: clean epoch extracted, quality={quality}")
        else:
            logger.warning(f"Trial {active.trial_id}: artifact detected – epoch discarded")
            self.lsl.push_marker(f"ARTIFACT;id={active.trial_id}")

        # Rebuild from trials rated >= config.min_quality_for_model.
        self._maybe_update_model(active.class_label)

        # NOTE(review): the state may already include this very trial (it was just
        # rebuilt above), which biases this score upward — confirm that is intended.
        nf = 0.0
        if active.features is not None:
            nf = self.nf_engine.compute_similarity(active.features, active.class_label)
        active.nf_score = nf
        return {
            "trial_id": active.trial_id,
            "quality": quality,
            "nf_score": round(nf, 4),
            "feedback": self.nf_engine.get_feedback_audio_params(nf),
            "model_updated": active.class_label in self.nf_engine.states,
        }

    def _extract_trial_epoch(self, trial: Trial) -> Optional[np.ndarray]:
        """Return the window [onset + skip, onset + skip + EPOCH_DURATION] of ``trial``.

        The buffer carries no timestamps. The window is therefore located by
        counting back from the newest buffered sample using the board's sample
        rate, assuming the newest sample is roughly "now" (acquisition polls about
        every 50 ms). If the trial ended before the window closed, the most recent
        EPOCH_DURATION seconds are used instead. Returns None when the window is
        no longer (or not yet) buffered.
        """
        skip = self._EPOCH_SKIP_S
        window_end = skip + EPOCH_DURATION
        elapsed = time.time() - trial.onset_time
        if elapsed < window_end:
            return self.board.get_epoch(EPOCH_DURATION, offset=skip)
        since_onset = self.board.get_epoch(elapsed)
        if since_onset is None:
            return None
        srate = self.board.srate
        stop = since_onset.shape[1] - int(round((elapsed - window_end) * srate))
        start = stop - int(round(EPOCH_DURATION * srate))
        if start < 0:
            return None
        return since_onset[:, start:stop]

    def _get_active_trial(self) -> Optional[Trial]:
        """Most recent trial that has not been rated yet, or None."""
        return next((t for t in reversed(self.trials) if t.quality_score is None), None)

    def _maybe_update_model(self, class_label: str):
        """Rebuild neural state from high-quality trials and retrain the classifier."""
        good_features = [
            t.features for t in self.trials
            if t.class_label == class_label
            and t.quality_score is not None
            and t.quality_score >= self.config.min_quality_for_model
            and t.features is not None
        ]
        if len(good_features) >= 3:
            self.nf_engine.update_state(class_label, good_features)
        # Retrain classifier if both classes have data
        self._retrain_classifier()

    def _retrain_classifier(self):
        """Fit the classifier on all high-quality trials once at least two classes exist."""
        good = [
            t for t in self.trials
            if t.features is not None and t.quality_score is not None
            and t.quality_score >= self.config.min_quality_for_model
        ]
        X = [t.features for t in good]
        y = [t.class_label for t in good]
        if len(set(y)) >= 2:
            self.classifier.fit(np.array(X), y)

    # ── Live Signals ───────────────────────────────────────────────────────
    def get_live_signals(self) -> dict:
        """Band powers, topomap, quality and raw traces for the latest 2 s window.

        Powers are scaled by 1e12 and traces by 1e6, so Volt input yields
        V²→µV²-scaled powers and µV traces. Returns zero placeholders when
        there is not enough data, and ``{}`` on processing errors.
        """
        epoch = self.board.get_psd_snapshot()
        if epoch is None:
            # Return synthetic zeros
            return {
                "channels": CHANNEL_NAMES,
                "alpha_power": [0.0] * N_CHANNELS,
                "beta_power": [0.0] * N_CHANNELS,
                "theta_power": [0.0] * N_CHANNELS,
                "raw_samples": [[0.0] * 50 for _ in range(4)],  # 4 independent channel previews
                "topomap": {n: 0.0 for n in CHANNEL_NAMES},
                "signal_quality": [0.0] * N_CHANNELS,
            }
        try:
            clean = self.processor.preprocess(epoch)
            bp = self.processor.band_power(clean)
            # Same values compute_topomap() returns, without filtering the epoch twice.
            topo = dict(zip(CHANNEL_NAMES, bp["alpha"]))
            # Signal quality: linear map of raw peak-to-peak amplitude,
            # 10 µV -> 1.0 down to 100 µV -> 0.0 (clipped).
            peak2peak = (epoch.max(axis=1) - epoch.min(axis=1)) * 1e6  # µV
            quality = [float(np.clip(1.0 - (pp - 10) / 90.0, 0.0, 1.0)) for pp in peak2peak]
            # Raw traces (µV) of 4 channels: the last 50 samples (200 ms at 250 Hz)
            sel = [0, 4, 8, 12]
            raw = (clean[sel, -50:] * 1e6).tolist()
            return {
                "channels": CHANNEL_NAMES,
                "alpha_power": [float(v * 1e12) for v in bp["alpha"]],
                "beta_power": [float(v * 1e12) for v in bp["beta"]],
                "theta_power": [float(v * 1e12) for v in bp["theta"]],
                "raw_samples": raw,
                "topomap": {k: float(v * 1e12) for k, v in topo.items()},
                "signal_quality": quality,
            }
        except Exception as e:
            logger.warning(f"Live signal error: {e}")
            return {}

    # ── Session State ──────────────────────────────────────────────────────
    def get_state(self) -> dict:
        """Snapshot of phase, config, trials per class, neural states and status flags."""
        class_a_trials = [t.to_dict() for t in self.trials if t.class_label == self.config.class_a]
        class_b_trials = [t.to_dict() for t in self.trials if t.class_label == self.config.class_b]
        states = {
            label: {
                "n_trials": state.n_trials,
                "convergence_score": round(state.convergence_score, 4),
                "created_at": state.created_at,
            }
            for label, state in self.nf_engine.states.items()
        }
        return {
            "phase": self.phase.value,
            "config": {
                "subject_id": self.config.subject_id,
                "session_id": self.config.session_id,
                "class_a": self.config.class_a,
                "class_b": self.config.class_b,
                "simulate": self.config.simulate,
                "min_quality_for_model": self.config.min_quality_for_model,
            },
            "class_a_trials": class_a_trials,
            "class_b_trials": class_b_trials,
            "neural_states": states,
            "classifier_trained": self.classifier.is_trained,
            "board_connected": self.board.connected,
            "total_trials": len(self.trials),
        }

    # ── WebSocket Broadcast ────────────────────────────────────────────────
    def add_ws_client(self, ws: WebSocket):
        self._ws_clients.append(ws)

    def remove_ws_client(self, ws: WebSocket):
        if ws in self._ws_clients:
            self._ws_clients.remove(ws)

    async def broadcast(self, data: dict):
        """Send ``data`` to every client, dropping clients whose send fails."""
        dead = []
        # Iterate a snapshot: clients may connect or disconnect while we await sends.
        for ws in list(self._ws_clients):
            try:
                await ws.send_json(data)
            except Exception:
                dead.append(ws)
        for ws in dead:
            self.remove_ws_client(ws)

    async def live_broadcast_loop(self):
        """Continuously push live EEG data (~10 Hz) to all connected WS clients."""
        loop = asyncio.get_running_loop()
        while True:
            await asyncio.sleep(0.1)
            if not (self._ws_clients and self.board.connected):
                continue
            try:
                # Filtering/PSD is CPU-bound: compute it in a worker thread so the
                # event loop (and every WebSocket) stays responsive.
                live = await loop.run_in_executor(None, self.get_live_signals)
            except Exception:
                logger.exception("Live signal computation failed")
                continue
            if live:
                live["type"] = "live_eeg"
                live["timestamp"] = time.time()
                await self.broadcast(live)
# ─── FastAPI App ──────────────────────────────────────────────────────────────
ROOT_DIR = Path(__file__).resolve().parent
SITE_HTML = ROOT_DIR / "site.html"  # single-page UI served at "/"

app = FastAPI(title="EEG Mental Imagery Backend", version="2.0.0")
# Permissive CORS: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# One global session shared by every REST call and WebSocket client.
session = SessionManager()
@app.get("/")
def serve_ui():
    """Serve site.html from the API/WebSocket origin, or a 404 JSON error if it is missing."""
    if SITE_HTML.is_file():
        return FileResponse(SITE_HTML, media_type="text/html; charset=utf-8")
    missing = {"error": "site.html not found", "path": str(SITE_HTML)}
    return JSONResponse(status_code=404, content=missing)
@app.on_event("startup")
async def startup():
    """Start the background live-EEG broadcast loop when the server starts."""
    # The event loop keeps only weak references to tasks; hold a strong one on
    # app.state so the broadcast loop cannot be garbage-collected mid-run.
    app.state.live_broadcast_task = asyncio.create_task(session.live_broadcast_loop())
    logger.info("EEG Backend started")
# ── REST Endpoints ─────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Liveness probe that also reports which optional dependencies were importable."""
    return dict(
        status="ok",
        brainflow=BRAINFLOW_AVAILABLE,
        lsl=LSL_AVAILABLE,
        sklearn=SKLEARN_AVAILABLE,
        scipy=SCIPY_AVAILABLE,
    )
@app.post("/board/connect")
def connect_board(port: str = "/dev/ttyUSB0", simulate: bool = False):
    """Connect on ``port``; simulation is forced whenever BrainFlow is not installed."""
    cfg = session.config
    cfg.port = port
    cfg.simulate = simulate or not BRAINFLOW_AVAILABLE
    status = session.connect_board()
    return status
@app.post("/board/disconnect")
def disconnect_board():
    """Close the board connection and emit the SESSION_END marker."""
    session.disconnect_board()
    return dict(status="disconnected")
@app.post("/session/configure")
def configure_session(
    subject_id: str = "S001",
    class_a: str = "Apple",
    class_b: str = "House",
    min_quality: int = 4
):
    """Update subject/class settings, assign a fresh session id and return the full config."""
    cfg = session.config
    cfg.subject_id, cfg.class_a, cfg.class_b = subject_id, class_a, class_b
    cfg.min_quality_for_model = min_quality
    # New timestamped id for data recorded under these settings.
    cfg.session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    return {"status": "configured", "config": asdict(cfg)}
@app.post("/trial/start")
def start_trial(class_label: str):
    """Begin a trial for ``class_label`` and return its summary."""
    return session.start_trial(class_label).to_dict()
@app.post("/trial/end")
def end_trial(quality: int):
    """Rate the active trial (1-5) and return its neurofeedback result."""
    result = session.end_trial(quality)
    return result
@app.get("/state")
def get_state():
    """Full session snapshot (phase, config, trials, neural states, flags)."""
    snapshot = session.get_state()
    return snapshot
@app.get("/signals/live")
def live_signals():
    """One-shot poll of the live EEG summary pushed over the WebSocket."""
    payload = session.get_live_signals()
    return payload
@app.get("/channels/info")
def channel_info():
    """Static montage description: channel count and order, 2-D positions, nominal rate."""
    return dict(
        n_channels=N_CHANNELS,
        channel_names=CHANNEL_NAMES,
        positions=ELECTRODE_POSITIONS_16CH,
        sample_rate=SAMPLE_RATE,
    )
@app.post("/lsl/marker")
def push_marker(marker: str):
    """Forward an arbitrary string marker to the LSL marker outlet (no-op without pylsl)."""
    session.lsl.push_marker(marker)
    pushed_at = time.time()
    return {"pushed": marker, "timestamp": pushed_at}
@app.get("/classifier/predict")
def predict_current():
    """Classify the latest EPOCH_DURATION seconds and score them against both classes' states."""
    epoch = session.board.get_epoch(EPOCH_DURATION)
    if epoch is None:
        return {"error": "no data"}
    feats = session.processor.extract_features(epoch)
    proba = session.classifier.predict_proba(feats)
    cfg = session.config
    similarity = {
        label: session.nf_engine.compute_similarity(feats, label)
        for label in (cfg.class_a, cfg.class_b)
    }
    return {
        "probabilities": proba,
        "nf_similarity": similarity,
        "timestamp": time.time(),
    }
@app.post("/session/reset")
def reset_session():
    """Forget all trials and neural states and return to IDLE; the board stays connected."""
    s = session
    s.nf_engine.states.clear()
    s.classifier.is_trained = False
    s.trials.clear()
    s.trial_counter = 0
    s.phase = SessionPhase.IDLE
    return {"status": "reset"}
# ── WebSocket ──────────────────────────────────────────────────────────────
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    """Bidirectional control channel for the UI.

    Accepted messages are JSON objects with a ``"type"`` key: ``ping``,
    ``state``, ``start_trial`` (needs ``class_label``), ``end_trial`` (needs
    ``quality``), ``configure`` and ``connect_board``. Unknown types are ignored.
    A malformed message gets a ``{"type": "error", ...}`` reply instead of
    terminating the handler. The client is always unregistered when the socket
    closes, however it closes.
    """
    await ws.accept()
    session.add_ws_client(ws)
    logger.info(f"WebSocket client connected. Total: {len(session._ws_clients)}")
    loop = asyncio.get_running_loop()
    try:
        while True:
            try:
                msg = await ws.receive_json()
            except (ValueError, TypeError, KeyError) as exc:
                # Undecodable frame (bad JSON / non-text frame); the socket is still usable.
                await ws.send_json({"type": "error", "error": f"invalid message: {exc}"})
                continue
            if not isinstance(msg, dict):
                await ws.send_json({"type": "error", "error": "message must be a JSON object"})
                continue
            msg_type = msg.get("type", "")
            try:
                if msg_type == "ping":
                    await ws.send_json({"type": "pong", "timestamp": time.time()})
                elif msg_type == "state":
                    await ws.send_json({"type": "state", **session.get_state()})
                elif msg_type == "start_trial":
                    trial = session.start_trial(msg["class_label"])
                    await ws.send_json({"type": "trial_started", **trial.to_dict()})
                elif msg_type == "end_trial":
                    # Epoch filtering/feature extraction is CPU work; keep it off the event loop.
                    result = await loop.run_in_executor(None, session.end_trial, msg["quality"])
                    await ws.send_json({"type": "trial_ended", **result})
                    # Broadcast state update to all clients
                    await session.broadcast({"type": "state_update", **session.get_state()})
                elif msg_type == "configure":
                    # Validate before mutating so a bad value leaves the config untouched.
                    mq = msg.get("min_quality")
                    min_quality = None if mq is None else int(mq)
                    cfg = session.config
                    cfg.subject_id = msg.get("subject_id", cfg.subject_id)
                    cfg.class_a = msg.get("class_a", cfg.class_a)
                    cfg.class_b = msg.get("class_b", cfg.class_b)
                    if min_quality is not None:
                        cfg.min_quality_for_model = min_quality
                    await ws.send_json({"type": "configured"})
                elif msg_type == "connect_board":
                    # Same rule as the REST endpoint: simulate when BrainFlow is missing.
                    session.config.simulate = msg.get("simulate", True) or not BRAINFLOW_AVAILABLE
                    if msg.get("port"):
                        session.config.port = str(msg["port"])
                    # Opening a serial board can block for seconds; run it in a worker thread.
                    result = await loop.run_in_executor(None, session.connect_board)
                    await ws.send_json({"type": "board_status", **result})
            except (KeyError, TypeError, ValueError) as exc:
                logger.warning(f"WebSocket message {msg_type!r} rejected: {exc!r}")
                await ws.send_json({
                    "type": "error",
                    "request": msg_type,
                    "error": f"{type(exc).__name__}: {exc}",
                })
    except WebSocketDisconnect:
        pass
    finally:
        session.remove_ws_client(ws)
        logger.info(f"WebSocket client disconnected. Remaining: {len(session._ws_clients)}")
# ─── Entry Point ──────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Pass the app object rather than the "backend:app" import string. The string
    # form makes uvicorn import this file a second time under the module name
    # "backend", which builds a second SessionManager (and a second set of LSL
    # outlets). It also fails outright if the file is not named backend.py.
    # The object form is valid because reload is disabled.
    uvicorn.run(app, host="0.0.0.0", port=8765, reload=False, log_level="info")