""" EEG Mental Imagery Classification Backend Full-stack server: OpenBCI Cyton+Daisy (16ch), LSL markers, neurofeedback, websocket API """ import asyncio import json import logging import math import os import queue import threading import time from collections import deque from dataclasses import asdict, dataclass, field from datetime import datetime from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import numpy as np from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, JSONResponse import uvicorn # ─── Optional heavy imports (graceful degradation) ─────────────────────────── try: from brainflow.board_shim import BoardShim, BrainFlowInputParams, BoardIds, BrainFlowError from brainflow.data_filter import DataFilter, FilterTypes, DetrendOperations, WindowOperations BRAINFLOW_AVAILABLE = True except ImportError: BRAINFLOW_AVAILABLE = False logging.warning("BrainFlow not installed – running in SIMULATION mode") try: from pylsl import StreamInfo, StreamOutlet, StreamInlet, resolve_streams, cf_string, cf_float32 LSL_AVAILABLE = True except ImportError: LSL_AVAILABLE = False logging.warning("pylsl not installed – LSL streaming disabled") try: import scipy.signal as signal from scipy.signal import welch, butter, filtfilt SCIPY_AVAILABLE = True except ImportError: SCIPY_AVAILABLE = False try: from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.preprocessing import StandardScaler from sklearn.pipeline import Pipeline SKLEARN_AVAILABLE = True except ImportError: SKLEARN_AVAILABLE = False # ─── Logging ───────────────────────────────────────────────────────────────── logging.basicConfig(level=logging.INFO, format="%(asctime)s │ %(levelname)s │ %(message)s") logger = logging.getLogger("EEG-Backend") # ─── Constants ──────────────────────────────────────────────────────────────── SAMPLE_RATE = 250 # Hz – Cyton+Daisy N_CHANNELS = 16 EPOCH_DURATION = 4.0 # seconds BASELINE_DURATION = 1.0 # seconds pre-stimulus BUFFER_SECONDS = 30 ALPHA_BAND = (8, 13) BETA_BAND = (13, 30) THETA_BAND = (4, 8) GAMMA_BAND = (30, 45) # Standard 10-20 positions for 16-ch parieto-occipital cap ELECTRODE_POSITIONS_16CH = { "P3": (-0.30, 0.50), "Pz": (0.00, 0.58), "P4": (0.30, 0.50), "P7": (-0.55, 0.42), "P8": (0.55, 0.42), "PO3": (-0.22, 0.38), "POz": (0.00, 0.38), "PO4": (0.22, 0.38), "PO7": (-0.40, 0.28), "PO8": (0.40, 0.28), "O1": (-0.20, 0.18), "Oz": (0.00, 0.18), "O2": (0.20, 0.18), "CP3": (-0.25, 0.68), "CPz": (0.00, 0.72), "CP4": (0.25, 0.68), } CHANNEL_NAMES = list(ELECTRODE_POSITIONS_16CH.keys()) # ─── Data Structures ────────────────────────────────────────────────────────── class SessionPhase(str, Enum): IDLE = "IDLE" BASELINE = "BASELINE" ACQUISITION = "ACQUISITION" FEEDBACK = "FEEDBACK" CONVERGENCE = "CONVERGENCE" LIBRARY = "LIBRARY" @dataclass class Trial: trial_id: int class_label: str onset_time: float quality_score: Optional[int] = None # 1-5 self-report eeg_epoch: Optional[np.ndarray] = None # (n_channels, n_samples) features: Optional[np.ndarray] = None nf_score: Optional[float] = None # neurofeedback distance [0,1] timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) def to_dict(self) -> dict: return { "trial_id": self.trial_id, "class_label": self.class_label, "onset_time": self.onset_time, "quality_score": self.quality_score, "nf_score": self.nf_score, "timestamp": self.timestamp, } @dataclass class NeuralState: """Stable neural state = cluster centroid in feature space""" class_label: str centroid: np.ndarray covariance: np.ndarray n_trials: int convergence_score: float # how tight the cluster is [0,1] created_at: str = field(default_factory=lambda: datetime.now().isoformat()) @dataclass class SessionConfig: subject_id: str = "S001" session_id: str = field(default_factory=lambda: datetime.now().strftime("%Y%m%d_%H%M%S")) class_a: str = "Class_A" class_b: str = "Class_B" n_trials_per_class: int = 20 min_quality_for_model: int = 4 # only use trials rated ≥ 4 feedback_threshold: float = 0.65 # similarity to converge port: str = "/dev/ttyUSB0" board_id: int = 2 # 2=Cyton+Daisy simulate: bool = not BRAINFLOW_AVAILABLE # ─── Signal Processing ──────────────────────────────────────────────────────── class SignalProcessor: def __init__(self, srate: int = SAMPLE_RATE, n_channels: int = N_CHANNELS): self.srate = srate self.n_channels = n_channels self._design_filters() def _design_filters(self): """Pre-compute Butterworth filters""" nyq = self.srate / 2 self.bp_coefs = {} for name, (lo, hi) in [("broad", (1, 45)), ("alpha", ALPHA_BAND), ("beta", BETA_BAND), ("theta", THETA_BAND)]: b, a = butter(4, [lo/nyq, hi/nyq], btype="band") self.bp_coefs[name] = (b, a) # Notch 50/60 Hz b60, a60 = butter(4, [58/nyq, 62/nyq], btype="bandstop") b50, a50 = butter(4, [48/nyq, 52/nyq], btype="bandstop") self.notch_60 = (b60, a60) self.notch_50 = (b50, a50) def preprocess(self, epoch: np.ndarray) -> np.ndarray: """Band-pass + notch filter. epoch: (n_ch, n_samples)""" out = epoch.copy().astype(float) b, a = self.bp_coefs["broad"] bn60, an60 = self.notch_60 for ch in range(out.shape[0]): out[ch] = filtfilt(b, a, out[ch]) out[ch] = filtfilt(bn60, an60, out[ch]) # Remove DC out[ch] -= out[ch].mean() return out def compute_psd(self, epoch: np.ndarray, fmax: float = 50.0) -> Tuple[np.ndarray, np.ndarray]: """Welch PSD per channel. Returns (freqs, psd): psd shape (n_ch, n_freqs)""" nperseg = min(self.srate, epoch.shape[1]) freqs, psd = welch(epoch, fs=self.srate, nperseg=nperseg, axis=1) mask = freqs <= fmax return freqs[mask], psd[:, mask] def band_power(self, epoch: np.ndarray) -> Dict[str, np.ndarray]: """Band power per channel per band. Returns dict of (n_ch,) arrays""" freqs, psd = self.compute_psd(epoch) powers = {} for name, (lo, hi) in [("alpha", ALPHA_BAND), ("beta", BETA_BAND), ("theta", THETA_BAND), ("gamma", GAMMA_BAND)]: idx = np.logical_and(freqs >= lo, freqs <= hi) powers[name] = np.trapz(psd[:, idx], freqs[idx], axis=1) return powers def extract_features(self, epoch: np.ndarray) -> np.ndarray: """ Feature vector: band-power ratios + log-band-powers + covariance diagonal Returns 1D numpy array """ clean = self.preprocess(epoch) bp = self.band_power(clean) alpha = bp["alpha"] + 1e-10 beta = bp["beta"] + 1e-10 theta = bp["theta"] + 1e-10 gamma = bp["gamma"] + 1e-10 feats = [] feats.extend(np.log(alpha)) # 16 log-alpha feats.extend(np.log(beta)) # 16 log-beta feats.extend(np.log(theta)) # 16 log-theta feats.extend(beta / alpha) # 16 beta/alpha feats.extend(theta / alpha) # 16 theta/alpha # Covariance diagonal (channel variances) cov = np.cov(clean) feats.extend(np.log(np.diag(cov) + 1e-10)) # 16 log-var # Frontal asymmetry (not applicable for occipital, skip ratio index) return np.array(feats, dtype=float) def compute_topomap(self, epoch: np.ndarray) -> Dict[str, float]: """Return per-channel alpha power for topomap display""" clean = self.preprocess(epoch) bp = self.band_power(clean) return {name: float(val) for name, val in zip(CHANNEL_NAMES, bp["alpha"])} def artifact_rejection(self, epoch: np.ndarray, threshold_uv: float = 100.0) -> bool: """Returns True if epoch is clean""" peak_to_peak = epoch.max(axis=1) - epoch.min(axis=1) return bool(np.all(peak_to_peak < threshold_uv * 1e-6)) # BrainFlow in Volts # ─── Neurofeedback Engine ───────────────────────────────────────────────────── class NeurofeedbackEngine: """ Computes similarity between current epoch features and stored neural state centroids. Score ∈ [0,1] where 1 = perfectly matching the target state. """ def __init__(self): self.states: Dict[str, NeuralState] = {} def update_state(self, class_label: str, features_list: List[np.ndarray]): """Build/update stable state from a list of high-quality feature vectors""" if len(features_list) < 3: return mat = np.stack(features_list) centroid = mat.mean(axis=0) cov = np.cov(mat.T) + np.eye(mat.shape[1]) * 1e-6 # Convergence = 1 - normalised mean intra-cluster distance dists = [np.linalg.norm(f - centroid) for f in features_list] convergence = max(0.0, 1.0 - np.mean(dists) / (np.std(dists) + 1.0)) self.states[class_label] = NeuralState( class_label=class_label, centroid=centroid, covariance=cov, n_trials=len(features_list), convergence_score=float(convergence), ) logger.info(f"Neural state updated for {class_label}: convergence={convergence:.3f}") def compute_similarity(self, features: np.ndarray, class_label: str) -> float: """Mahalanobis-based similarity to target state""" if class_label not in self.states: return 0.0 state = self.states[class_label] diff = features - state.centroid try: inv_cov = np.linalg.pinv(state.covariance) dist = float(np.sqrt(diff @ inv_cov @ diff)) # Sigmoid mapping: dist=0 → score=1, dist=5 → score≈0.5 score = 1.0 / (1.0 + dist / 3.0) except Exception: score = 0.0 return float(np.clip(score, 0.0, 1.0)) def get_feedback_audio_params(self, score: float) -> dict: """Returns audio synthesis params based on NF score""" freq_hz = 200 + score * 600 # 200–800 Hz volume = 0.3 + score * 0.7 # 0.3–1.0 tone = "positive" if score > 0.65 else ("neutral" if score > 0.35 else "negative") return {"freq_hz": round(freq_hz, 1), "volume": round(volume, 3), "tone": tone} # ─── LDA Classifier ─────────────────────────────────────────────────────────── class EEGClassifier: def __init__(self): if SKLEARN_AVAILABLE: self.model = Pipeline([ ("scaler", StandardScaler()), ("lda", LinearDiscriminantAnalysis(solver="svd")), ]) self.is_trained = False self.classes_: List[str] = [] def fit(self, X: np.ndarray, y: List[str]): if not SKLEARN_AVAILABLE or len(X) < 6: return False self.model.fit(X, y) self.is_trained = True self.classes_ = list(np.unique(y)) logger.info(f"LDA trained on {len(X)} epochs, classes={self.classes_}") return True def predict_proba(self, features: np.ndarray) -> Dict[str, float]: if not self.is_trained: return {} proba = self.model.predict_proba(features.reshape(1, -1))[0] return {cls: float(p) for cls, p in zip(self.classes_, proba)} # ─── LSL Manager ────────────────────────────────────────────────────────────── class LSLManager: def __init__(self, stream_name: str = "EEGMarkers"): self.outlet = None self.eeg_outlet = None if LSL_AVAILABLE: self._init_marker_stream(stream_name) self._init_eeg_stream() def _init_marker_stream(self, name: str): info = StreamInfo(name, "Markers", 1, 0, cf_string, f"markers-{name}") info.desc().append_child_value("manufacturer", "EEG-MI-Backend") self.outlet = StreamOutlet(info) logger.info(f"LSL marker stream '{name}' opened") def _init_eeg_stream(self): info = StreamInfo("EEG-MI", "EEG", N_CHANNELS, SAMPLE_RATE, cf_float32, "eeg-mi") chns = info.desc().append_child("channels") for name in CHANNEL_NAMES: ch = chns.append_child("channel") ch.append_child_value("label", name) ch.append_child_value("unit", "microvolts") ch.append_child_value("type", "EEG") self.eeg_outlet = StreamOutlet(info, 32, 360) logger.info("LSL EEG stream opened") def push_marker(self, marker: str): if self.outlet: self.outlet.push_sample([marker]) logger.debug(f"LSL marker: {marker}") def push_eeg_chunk(self, chunk: np.ndarray): """chunk: (n_samples, n_channels)""" if self.eeg_outlet and chunk.size > 0: self.eeg_outlet.push_chunk(chunk.tolist()) # ─── Board Manager ──────────────────────────────────────────────────────────── class BoardManager: """Handles Cyton+Daisy board or simulation""" def __init__(self, config: SessionConfig): self.config = config self.board = None self.board_id = config.board_id self.srate = SAMPLE_RATE self._running = False self._thread: Optional[threading.Thread] = None self._buffer = deque(maxlen=BUFFER_SECONDS * SAMPLE_RATE) self._lock = threading.Lock() self.connected = False # Simulation state self._sim_t = 0.0 self._sim_phase = 0.0 def connect(self) -> bool: if self.config.simulate: logger.info("Board: SIMULATION mode active") self.connected = True return True if not BRAINFLOW_AVAILABLE: logger.error("BrainFlow not available; cannot connect to hardware") return False params = BrainFlowInputParams() params.serial_port = self.config.port try: BoardShim.enable_dev_board_logger() self.board = BoardShim(self.board_id, params) self.board.prepare_session() self.board.start_stream(45000) self.srate = BoardShim.get_sampling_rate(self.board_id) self.connected = True logger.info(f"Cyton+Daisy connected: {self.config.port} @ {self.srate} Hz") return True except BrainFlowError as e: logger.error(f"Board connection failed: {e}") return False def disconnect(self): self._running = False if self.board and BRAINFLOW_AVAILABLE: try: self.board.stop_stream() self.board.release_session() except Exception: pass self.connected = False def start_acquisition(self): self._running = True self._thread = threading.Thread(target=self._acquisition_loop, daemon=True) self._thread.start() def stop_acquisition(self): self._running = False def _acquisition_loop(self): while self._running: if self.config.simulate: chunk = self._generate_sim_chunk(50) # 50 samples at a time else: chunk = self._read_board_chunk() if chunk is not None and chunk.shape[1] > 0: with self._lock: for s in range(chunk.shape[1]): self._buffer.append(chunk[:, s]) time.sleep(0.05) def _read_board_chunk(self) -> Optional[np.ndarray]: try: data = self.board.get_board_data() eeg_channels = BoardShim.get_eeg_channels(self.board_id) if data.shape[1] == 0: return None return data[eeg_channels[:N_CHANNELS], :] # (16, n_samples) except Exception as e: logger.warning(f"Board read error: {e}") return None def _generate_sim_chunk(self, n_samples: int) -> np.ndarray: """Realistic EEG simulation: alpha rhythm + noise + ocular artifacts""" chunk = np.zeros((N_CHANNELS, n_samples)) t = np.linspace(self._sim_t, self._sim_t + n_samples/SAMPLE_RATE, n_samples) self._sim_t += n_samples / SAMPLE_RATE for ch in range(N_CHANNELS): # Alpha (8-12 Hz) with spatial gradient alpha_amp = (0.5 + 0.5 * math.sin(ch * 0.4)) * 15e-6 alpha = alpha_amp * np.sin(2 * np.pi * 10 * t + ch * 0.3) # Beta (13-25 Hz) beta = 5e-6 * np.sin(2 * np.pi * 20 * t + ch * 0.6) # Theta (4-8 Hz) – stronger in frontal (not this cap, but simulated) theta = 8e-6 * np.sin(2 * np.pi * 6 * t) # White noise noise = np.random.randn(n_samples) * 3e-6 chunk[ch] = alpha + beta + theta + noise # Occasional blink artifact on ch 0-1 if np.random.rand() < 0.02: idx = np.random.randint(0, max(1, n_samples - 20)) pulse = np.hanning(20) * 150e-6 end = min(idx + 20, n_samples) chunk[0, idx:end] += pulse[:end-idx] chunk[1, idx:end] += pulse[:end-idx] * 0.5 return chunk def get_epoch(self, duration: float, offset: float = 0.0) -> Optional[np.ndarray]: """Extract latest epoch of `duration` seconds. Returns (n_ch, n_samples)""" n_wanted = int((duration + offset) * SAMPLE_RATE) with self._lock: buf = list(self._buffer) if len(buf) < n_wanted: return None arr = np.array(buf[-n_wanted:]).T # (n_ch, n_samples) # Return only the epoch portion (skip offset) n_offset = int(offset * SAMPLE_RATE) return arr[:, n_offset:] def get_psd_snapshot(self) -> Optional[np.ndarray]: """Latest 2-second window for live display (n_ch, n_samples)""" return self.get_epoch(2.0) # ─── Session Manager ────────────────────────────────────────────────────────── class SessionManager: def __init__(self): self.config = SessionConfig() self.phase = SessionPhase.IDLE self.trials: List[Trial] = [] self.trial_counter = 0 self.current_class: Optional[str] = None self.target_class: Optional[str] = None self.processor = SignalProcessor() self.nf_engine = NeurofeedbackEngine() self.classifier = EEGClassifier() self.lsl = LSLManager() self.board = BoardManager(self.config) self._ws_clients: List[WebSocket] = [] self._event_queue: asyncio.Queue = asyncio.Queue() self._current_loop: Optional[asyncio.AbstractEventLoop] = None # ── Connection Management ────────────────────────────────────────────── def connect_board(self) -> dict: ok = self.board.connect() if ok: self.board.start_acquisition() self.lsl.push_marker("SESSION_START") return {"status": "connected" if ok else "error", "simulate": self.config.simulate} def disconnect_board(self): self.board.disconnect() self.lsl.push_marker("SESSION_END") # ── Trial Management ─────────────────────────────────────────────────── def start_trial(self, class_label: str) -> Trial: self.trial_counter += 1 onset = time.time() trial = Trial( trial_id=self.trial_counter, class_label=class_label, onset_time=onset, ) self.trials.append(trial) self.current_class = class_label self.phase = SessionPhase.ACQUISITION marker = f"TRIAL_START;class={class_label};id={self.trial_counter}" self.lsl.push_marker(marker) logger.info(f"Trial {self.trial_counter} started: {class_label}") return trial def end_trial(self, quality: int) -> dict: """Called after subject rates the trial 1-5""" active = self._get_active_trial() if not active: return {"error": "no active trial"} active.quality_score = quality self.lsl.push_marker(f"TRIAL_END;id={active.trial_id};quality={quality}") # Extract epoch (4s before now, skip first 0.5s baseline) epoch = self.board.get_epoch(EPOCH_DURATION, offset=0.5) if epoch is not None: # Artifact check clean = self.processor.artifact_rejection(epoch) if clean: active.eeg_epoch = epoch active.features = self.processor.extract_features(epoch) logger.info(f"Trial {active.trial_id}: clean epoch extracted, quality={quality}") else: logger.warning(f"Trial {active.trial_id}: artifact detected – epoch discarded") self.lsl.push_marker(f"ARTIFACT;id={active.trial_id}") # Update model if quality ≥ threshold self._maybe_update_model(active.class_label) # Compute NF score if model exists nf = 0.0 if active.features is not None: nf = self.nf_engine.compute_similarity(active.features, active.class_label) active.nf_score = nf return { "trial_id": active.trial_id, "quality": quality, "nf_score": round(nf, 4), "feedback": self.nf_engine.get_feedback_audio_params(nf), "model_updated": active.class_label in self.nf_engine.states, } def _get_active_trial(self) -> Optional[Trial]: for t in reversed(self.trials): if t.quality_score is None: return t return None def _maybe_update_model(self, class_label: str): """Rebuild neural state from high-quality trials""" good_features = [ t.features for t in self.trials if t.class_label == class_label and t.quality_score is not None and t.quality_score >= self.config.min_quality_for_model and t.features is not None ] if len(good_features) >= 3: self.nf_engine.update_state(class_label, good_features) # Retrain classifier if both classes have data self._retrain_classifier() def _retrain_classifier(self): X, y = [], [] for t in self.trials: if t.features is not None and t.quality_score is not None \ and t.quality_score >= self.config.min_quality_for_model: X.append(t.features) y.append(t.class_label) if len(set(y)) >= 2: self.classifier.fit(np.array(X), y) # ── Live Signals ─────────────────────────────────────────────────────── def get_live_signals(self) -> dict: epoch = self.board.get_psd_snapshot() if epoch is None: # Return synthetic zeros return { "channels": CHANNEL_NAMES, "alpha_power": [0.0] * N_CHANNELS, "beta_power": [0.0] * N_CHANNELS, "theta_power": [0.0] * N_CHANNELS, "raw_samples": [[0.0] * 50] * 4, # 4 channels preview "topomap": {n: 0.0 for n in CHANNEL_NAMES}, "signal_quality": [0.0] * N_CHANNELS, } try: clean = self.processor.preprocess(epoch) bp = self.processor.band_power(clean) topo = self.processor.compute_topomap(epoch) # Signal quality: inverse of variance relative to expected range peak2peak = (epoch.max(axis=1) - epoch.min(axis=1)) * 1e6 # µV quality = [float(np.clip(1.0 - (pp - 10) / 90.0, 0.0, 1.0)) for pp in peak2peak] # Raw traces for 4 selected channels (µV) sel = [0, 4, 8, 12] raw = (clean[sel, -50:] * 1e6).tolist() # last 200ms return { "channels": CHANNEL_NAMES, "alpha_power": [float(v * 1e12) for v in bp["alpha"]], "beta_power": [float(v * 1e12) for v in bp["beta"]], "theta_power": [float(v * 1e12) for v in bp["theta"]], "raw_samples": raw, "topomap": {k: float(v * 1e12) for k, v in topo.items()}, "signal_quality": quality, } except Exception as e: logger.warning(f"Live signal error: {e}") return {} # ── Session State ────────────────────────────────────────────────────── def get_state(self) -> dict: class_a_trials = [t.to_dict() for t in self.trials if t.class_label == self.config.class_a] class_b_trials = [t.to_dict() for t in self.trials if t.class_label == self.config.class_b] states = {} for label, state in self.nf_engine.states.items(): states[label] = { "n_trials": state.n_trials, "convergence_score": round(state.convergence_score, 4), "created_at": state.created_at, } return { "phase": self.phase.value, "config": { "subject_id": self.config.subject_id, "session_id": self.config.session_id, "class_a": self.config.class_a, "class_b": self.config.class_b, "simulate": self.config.simulate, "min_quality_for_model": self.config.min_quality_for_model, }, "class_a_trials": class_a_trials, "class_b_trials": class_b_trials, "neural_states": states, "classifier_trained": self.classifier.is_trained, "board_connected": self.board.connected, "total_trials": len(self.trials), } # ── WebSocket Broadcast ──────────────────────────────────────────────── def add_ws_client(self, ws: WebSocket): self._ws_clients.append(ws) def remove_ws_client(self, ws: WebSocket): if ws in self._ws_clients: self._ws_clients.remove(ws) async def broadcast(self, data: dict): dead = [] for ws in self._ws_clients: try: await ws.send_json(data) except Exception: dead.append(ws) for ws in dead: self.remove_ws_client(ws) async def live_broadcast_loop(self): """Continuously push live EEG data to all connected WS clients""" while True: await asyncio.sleep(0.1) # 10 Hz update if self._ws_clients and self.board.connected: live = self.get_live_signals() if live: live["type"] = "live_eeg" live["timestamp"] = time.time() await self.broadcast(live) # ─── FastAPI App ────────────────────────────────────────────────────────────── ROOT_DIR = Path(__file__).resolve().parent SITE_HTML = ROOT_DIR / "site.html" app = FastAPI(title="EEG Mental Imagery Backend", version="2.0.0") app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) session = SessionManager() @app.get("/") def serve_ui(): """Serve the project UI (site.html) from the same origin as the API / WebSocket.""" if not SITE_HTML.is_file(): return JSONResponse( status_code=404, content={"error": "site.html not found", "path": str(SITE_HTML)}, ) return FileResponse(SITE_HTML, media_type="text/html; charset=utf-8") @app.on_event("startup") async def startup(): asyncio.create_task(session.live_broadcast_loop()) logger.info("EEG Backend started") # ── REST Endpoints ───────────────────────────────────────────────────────── @app.get("/health") def health(): return {"status": "ok", "brainflow": BRAINFLOW_AVAILABLE, "lsl": LSL_AVAILABLE, "sklearn": SKLEARN_AVAILABLE, "scipy": SCIPY_AVAILABLE} @app.post("/board/connect") def connect_board(port: str = "/dev/ttyUSB0", simulate: bool = False): session.config.port = port session.config.simulate = simulate or not BRAINFLOW_AVAILABLE return session.connect_board() @app.post("/board/disconnect") def disconnect_board(): session.disconnect_board() return {"status": "disconnected"} @app.post("/session/configure") def configure_session( subject_id: str = "S001", class_a: str = "Apple", class_b: str = "House", min_quality: int = 4 ): session.config.subject_id = subject_id session.config.class_a = class_a session.config.class_b = class_b session.config.min_quality_for_model = min_quality session.config.session_id = datetime.now().strftime("%Y%m%d_%H%M%S") return {"status": "configured", "config": asdict(session.config)} @app.post("/trial/start") def start_trial(class_label: str): trial = session.start_trial(class_label) return trial.to_dict() @app.post("/trial/end") def end_trial(quality: int): return session.end_trial(quality) @app.get("/state") def get_state(): return session.get_state() @app.get("/signals/live") def live_signals(): return session.get_live_signals() @app.get("/channels/info") def channel_info(): return { "n_channels": N_CHANNELS, "channel_names": CHANNEL_NAMES, "positions": ELECTRODE_POSITIONS_16CH, "sample_rate": SAMPLE_RATE, } @app.post("/lsl/marker") def push_marker(marker: str): session.lsl.push_marker(marker) return {"pushed": marker, "timestamp": time.time()} @app.get("/classifier/predict") def predict_current(): epoch = session.board.get_epoch(EPOCH_DURATION) if epoch is None: return {"error": "no data"} feats = session.processor.extract_features(epoch) proba = session.classifier.predict_proba(feats) nf_a = session.nf_engine.compute_similarity(feats, session.config.class_a) nf_b = session.nf_engine.compute_similarity(feats, session.config.class_b) return { "probabilities": proba, "nf_similarity": {session.config.class_a: nf_a, session.config.class_b: nf_b}, "timestamp": time.time(), } @app.post("/session/reset") def reset_session(): session.trials.clear() session.trial_counter = 0 session.nf_engine.states.clear() session.classifier.is_trained = False session.phase = SessionPhase.IDLE return {"status": "reset"} # ── WebSocket ────────────────────────────────────────────────────────────── @app.websocket("/ws") async def websocket_endpoint(ws: WebSocket): await ws.accept() session.add_ws_client(ws) logger.info(f"WebSocket client connected. Total: {len(session._ws_clients)}") try: while True: msg = await ws.receive_json() msg_type = msg.get("type", "") if msg_type == "ping": await ws.send_json({"type": "pong", "timestamp": time.time()}) elif msg_type == "state": await ws.send_json({"type": "state", **session.get_state()}) elif msg_type == "start_trial": trial = session.start_trial(msg["class_label"]) await ws.send_json({"type": "trial_started", **trial.to_dict()}) elif msg_type == "end_trial": result = session.end_trial(msg["quality"]) await ws.send_json({"type": "trial_ended", **result}) # Broadcast state update to all clients await session.broadcast({"type": "state_update", **session.get_state()}) elif msg_type == "configure": session.config.subject_id = msg.get("subject_id", session.config.subject_id) session.config.class_a = msg.get("class_a", session.config.class_a) session.config.class_b = msg.get("class_b", session.config.class_b) mq = msg.get("min_quality") if mq is not None: session.config.min_quality_for_model = int(mq) await ws.send_json({"type": "configured"}) elif msg_type == "connect_board": session.config.simulate = msg.get("simulate", True) if msg.get("port"): session.config.port = str(msg["port"]) result = session.connect_board() await ws.send_json({"type": "board_status", **result}) except WebSocketDisconnect: session.remove_ws_client(ws) logger.info(f"WebSocket client disconnected. Remaining: {len(session._ws_clients)}") # ─── Entry Point ────────────────────────────────────────────────────────────── if __name__ == "__main__": uvicorn.run("backend:app", host="0.0.0.0", port=8765, reload=False, log_level="info")