| """ |
| EEG Mental Imagery Classification Backend |
| Full-stack server: OpenBCI Cyton+Daisy (16ch), LSL markers, neurofeedback, websocket API |
| """ |
|
|
| import asyncio |
| import json |
| import logging |
| import math |
| import os |
| import queue |
| import threading |
| import time |
| from collections import deque |
| from dataclasses import asdict, dataclass, field |
| from datetime import datetime |
| from enum import Enum |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import FileResponse, JSONResponse |
| import uvicorn |
|
|
| |
# Optional hardware stack: without BrainFlow the server still runs, but only
# in simulation mode (see SessionConfig.simulate / BoardManager).
try:
    from brainflow.board_shim import BoardShim, BrainFlowInputParams, BoardIds, BrainFlowError
    from brainflow.data_filter import DataFilter, FilterTypes, DetrendOperations, WindowOperations
    BRAINFLOW_AVAILABLE = True
except ImportError:
    BRAINFLOW_AVAILABLE = False
    # Fixed mojibake separator ("β") in the original warning text.
    logging.warning("BrainFlow not installed - running in SIMULATION mode")
|
|
# Optional Lab Streaming Layer stack: markers/EEG outlets become no-ops
# (see LSLManager) when pylsl is missing.
try:
    from pylsl import StreamInfo, StreamOutlet, StreamInlet, resolve_streams, cf_string, cf_float32
    LSL_AVAILABLE = True
except ImportError:
    LSL_AVAILABLE = False
    # Fixed mojibake separator ("β") in the original warning text.
    logging.warning("pylsl not installed - LSL streaming disabled")
|
|
| try: |
| import scipy.signal as signal |
| from scipy.signal import welch, butter, filtfilt |
| SCIPY_AVAILABLE = True |
| except ImportError: |
| SCIPY_AVAILABLE = False |
|
|
| try: |
| from sklearn.discriminant_analysis import LinearDiscriminantAnalysis |
| from sklearn.preprocessing import StandardScaler |
| from sklearn.pipeline import Pipeline |
| SKLEARN_AVAILABLE = True |
| except ImportError: |
| SKLEARN_AVAILABLE = False |
|
|
| |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s β %(levelname)s β %(message)s") |
| logger = logging.getLogger("EEG-Backend") |
|
|
| |
# --- Acquisition timing & geometry -------------------------------------------
SAMPLE_RATE = 250        # Hz; default used for buffers/sim — BoardManager.connect() reads the board's true rate
N_CHANNELS = 16          # Cyton + Daisy EEG channel count
EPOCH_DURATION = 4.0     # seconds of EEG extracted per trial
BASELINE_DURATION = 1.0  # seconds; not referenced elsewhere in this file
BUFFER_SECONDS = 30      # ring-buffer capacity, in seconds
# Frequency bands in Hz.
ALPHA_BAND = (8, 13)
BETA_BAND = (13, 30)
THETA_BAND = (4, 8)
GAMMA_BAND = (30, 45)


# Per-electrode (x, y) display coordinates — presumably normalized head-plot
# positions consumed by the frontend topomap (TODO confirm against the UI).
ELECTRODE_POSITIONS_16CH = {
    "P3": (-0.30, 0.50), "Pz": (0.00, 0.58), "P4": (0.30, 0.50),
    "P7": (-0.55, 0.42), "P8": (0.55, 0.42),
    "PO3": (-0.22, 0.38), "POz": (0.00, 0.38), "PO4": (0.22, 0.38),
    "PO7": (-0.40, 0.28), "PO8": (0.40, 0.28),
    "O1": (-0.20, 0.18), "Oz": (0.00, 0.18), "O2": (0.20, 0.18),
    "CP3": (-0.25, 0.68), "CPz": (0.00, 0.72), "CP4": (0.25, 0.68),
}
# Channel ordering throughout the pipeline follows this dict's insertion order.
CHANNEL_NAMES = list(ELECTRODE_POSITIONS_16CH.keys())
|
|
| |
|
|
class SessionPhase(str, Enum):
    """Workflow phases of a recording session.

    Subclasses ``str`` so phase values serialize directly into JSON payloads
    (see SessionManager.get_state).
    """
    IDLE = "IDLE"
    BASELINE = "BASELINE"
    ACQUISITION = "ACQUISITION"
    FEEDBACK = "FEEDBACK"
    CONVERGENCE = "CONVERGENCE"
    LIBRARY = "LIBRARY"
|
|
@dataclass
class Trial:
    """One mental-imagery trial: label, timing, subjective rating and derived data."""
    trial_id: int
    class_label: str
    onset_time: float
    quality_score: Optional[int] = None
    eeg_epoch: Optional[np.ndarray] = None
    features: Optional[np.ndarray] = None
    nf_score: Optional[float] = None
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> dict:
        """JSON-safe summary; the numpy payloads are deliberately left out."""
        exported = (
            "trial_id",
            "class_label",
            "onset_time",
            "quality_score",
            "nf_score",
            "timestamp",
        )
        return {name: getattr(self, name) for name in exported}
|
|
@dataclass
class NeuralState:
    """Stable neural state = cluster centroid in feature space."""
    class_label: str          # imagery class this state models
    centroid: np.ndarray      # mean feature vector over the contributing trials
    covariance: np.ndarray    # regularized feature covariance (built in NeurofeedbackEngine.update_state)
    n_trials: int             # number of feature vectors aggregated
    convergence_score: float  # heuristic in [0, 1); higher means a tighter cluster
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
|
@dataclass
class SessionConfig:
    """Mutable runtime configuration; one shared instance lives on SessionManager."""
    subject_id: str = "S001"
    session_id: str = field(default_factory=lambda: datetime.now().strftime("%Y%m%d_%H%M%S"))
    class_a: str = "Class_A"        # label of the first imagery class
    class_b: str = "Class_B"        # label of the second imagery class
    n_trials_per_class: int = 20    # not referenced elsewhere in this file
    min_quality_for_model: int = 4  # subjective 1-5 rating gate for training data
    feedback_threshold: float = 0.65  # not referenced elsewhere in this file
    port: str = "/dev/ttyUSB0"      # serial device for the board dongle
    board_id: int = 2               # BrainFlow board id (presumably Cyton+Daisy — TODO confirm)
    simulate: bool = not BRAINFLOW_AVAILABLE  # fall back to simulation when BrainFlow is absent
|
|
| |
|
|
class SignalProcessor:
    """Filtering, spectral estimation and feature extraction for multi-channel EEG.

    Epoch arrays are shaped (n_channels, n_samples). Amplitudes are assumed to
    be in volts (thresholds below scale from microvolts) — TODO confirm with
    the acquisition path.
    """

    def __init__(self, srate: int = SAMPLE_RATE, n_channels: int = N_CHANNELS):
        self.srate = srate
        self.n_channels = n_channels
        self._design_filters()

    def _design_filters(self) -> None:
        """Pre-compute Butterworth band-pass and 50/60 Hz notch coefficients."""
        nyq = self.srate / 2
        self.bp_coefs = {}
        for name, (lo, hi) in [("broad", (1, 45)), ("alpha", ALPHA_BAND),
                               ("beta", BETA_BAND), ("theta", THETA_BAND)]:
            b, a = butter(4, [lo / nyq, hi / nyq], btype="band")
            self.bp_coefs[name] = (b, a)
        # Both mains notches are prepared so `preprocess` can serve 50 Hz and
        # 60 Hz regions alike (notch_50 was previously computed but unused).
        self.notch_60 = butter(4, [58 / nyq, 62 / nyq], btype="bandstop")
        self.notch_50 = butter(4, [48 / nyq, 52 / nyq], btype="bandstop")

    def preprocess(self, epoch: np.ndarray, line_freq: int = 60) -> np.ndarray:
        """Broad band-pass + mains notch + per-channel mean removal.

        Args:
            epoch: raw data, shape (n_ch, n_samples).
            line_freq: mains frequency to notch, 60 (default, original behavior)
                or 50 Hz.
        Returns:
            A filtered copy; the input array is not modified.
        """
        out = epoch.astype(float, copy=True)
        b, a = self.bp_coefs["broad"]
        bn, an = self.notch_50 if line_freq == 50 else self.notch_60
        for ch in range(out.shape[0]):
            out[ch] = filtfilt(b, a, out[ch])
            out[ch] = filtfilt(bn, an, out[ch])
            out[ch] -= out[ch].mean()  # remove residual DC offset
        return out

    def compute_psd(self, epoch: np.ndarray, fmax: float = 50.0) -> Tuple[np.ndarray, np.ndarray]:
        """Welch PSD per channel. Returns (freqs, psd); psd shape (n_ch, n_freqs)."""
        # Cap nperseg so short epochs still yield a valid estimate.
        nperseg = min(self.srate, epoch.shape[1])
        freqs, psd = welch(epoch, fs=self.srate, nperseg=nperseg, axis=1)
        mask = freqs <= fmax
        return freqs[mask], psd[:, mask]

    def band_power(self, epoch: np.ndarray) -> Dict[str, np.ndarray]:
        """Integrated band power per channel per band; dict of (n_ch,) arrays.

        Uses np.trapezoid when available: np.trapz was removed in NumPy 2.0.
        """
        freqs, psd = self.compute_psd(epoch)
        integrate = getattr(np, "trapezoid", None) or np.trapz
        powers = {}
        for name, (lo, hi) in [("alpha", ALPHA_BAND), ("beta", BETA_BAND),
                               ("theta", THETA_BAND), ("gamma", GAMMA_BAND)]:
            idx = np.logical_and(freqs >= lo, freqs <= hi)
            powers[name] = integrate(psd[:, idx], freqs[idx], axis=1)
        return powers

    def extract_features(self, epoch: np.ndarray) -> np.ndarray:
        """Feature vector: log band powers, band ratios and log signal variance.

        Returns a 1-D array of length 6 * n_channels
        (log alpha/beta/theta, beta/alpha, theta/alpha, log variance).
        """
        clean = self.preprocess(epoch)
        bp = self.band_power(clean)
        # Epsilon guards the log/ratio computations against zero power.
        alpha = bp["alpha"] + 1e-10
        beta = bp["beta"] + 1e-10
        theta = bp["theta"] + 1e-10

        feats = []
        feats.extend(np.log(alpha))
        feats.extend(np.log(beta))
        feats.extend(np.log(theta))
        feats.extend(beta / alpha)
        feats.extend(theta / alpha)
        # Per-channel variance (diagonal of the spatial covariance).
        cov = np.cov(clean)
        feats.extend(np.log(np.diag(cov) + 1e-10))

        return np.array(feats, dtype=float)

    def compute_topomap(self, epoch: np.ndarray) -> Dict[str, float]:
        """Per-channel alpha power keyed by electrode name, for topomap display."""
        clean = self.preprocess(epoch)
        bp = self.band_power(clean)
        return {name: float(val) for name, val in zip(CHANNEL_NAMES, bp["alpha"])}

    def artifact_rejection(self, epoch: np.ndarray, threshold_uv: float = 100.0) -> bool:
        """True (clean) when every channel's peak-to-peak is under `threshold_uv` µV."""
        peak_to_peak = epoch.max(axis=1) - epoch.min(axis=1)
        return bool(np.all(peak_to_peak < threshold_uv * 1e-6))
|
|
| |
|
|
class NeurofeedbackEngine:
    """
    Scores how closely the current epoch's features match a stored neural-state
    centroid. Scores live in [0, 1]; 1 means a perfect match with the target.
    """

    def __init__(self):
        self.states = {}  # class label -> NeuralState

    def update_state(self, class_label: str, features_list: List[np.ndarray]):
        """Rebuild the stable state for one class from high-quality feature vectors."""
        if len(features_list) < 3:
            return  # too few samples for a meaningful centroid/covariance
        mat = np.stack(features_list)
        centroid = mat.mean(axis=0)
        # Small diagonal load keeps the covariance numerically invertible.
        cov = np.cov(mat.T) + 1e-6 * np.eye(mat.shape[1])
        dists = np.array([np.linalg.norm(vec - centroid) for vec in features_list])
        convergence = max(0.0, 1.0 - dists.mean() / (dists.std() + 1.0))
        self.states[class_label] = NeuralState(
            class_label=class_label,
            centroid=centroid,
            covariance=cov,
            n_trials=len(features_list),
            convergence_score=float(convergence),
        )
        logger.info(f"Neural state updated for {class_label}: convergence={convergence:.3f}")

    def compute_similarity(self, features: np.ndarray, class_label: str) -> float:
        """Mahalanobis-distance-based similarity to the target state, in [0, 1]."""
        state = self.states.get(class_label)
        if state is None:
            return 0.0
        delta = features - state.centroid
        try:
            precision = np.linalg.pinv(state.covariance)
            mahal = float(np.sqrt(delta @ precision @ delta))
            score = 1.0 / (1.0 + mahal / 3.0)  # squash distance into (0, 1]
        except Exception:
            score = 0.0
        return float(np.clip(score, 0.0, 1.0))

    def get_feedback_audio_params(self, score: float) -> dict:
        """Map an NF score to synthesis parameters for the audio feedback tone."""
        if score > 0.65:
            tone = "positive"
        elif score > 0.35:
            tone = "neutral"
        else:
            tone = "negative"
        return {
            "freq_hz": round(200 + score * 600, 1),
            "volume": round(0.3 + score * 0.7, 3),
            "tone": tone,
        }
|
|
| |
|
|
class EEGClassifier:
    """Thin LDA wrapper (with standardization) over quality-gated trial features."""

    def __init__(self):
        self.is_trained = False
        self.classes_: List[str] = []
        if SKLEARN_AVAILABLE:
            steps = [
                ("scaler", StandardScaler()),
                ("lda", LinearDiscriminantAnalysis(solver="svd")),
            ]
            self.model = Pipeline(steps)

    def fit(self, X: np.ndarray, y: List[str]):
        """Train on (n_epochs, n_features); requires sklearn and >= 6 epochs."""
        if not SKLEARN_AVAILABLE or len(X) < 6:
            return False
        self.model.fit(X, y)
        self.is_trained = True
        self.classes_ = list(np.unique(y))
        logger.info(f"LDA trained on {len(X)} epochs, classes={self.classes_}")
        return True

    def predict_proba(self, features: np.ndarray) -> Dict[str, float]:
        """Class -> probability for one feature vector; empty dict until trained."""
        if not self.is_trained:
            return {}
        row = features.reshape(1, -1)
        probs = self.model.predict_proba(row)[0]
        return dict(zip(self.classes_, (float(p) for p in probs)))
|
|
| |
|
|
class LSLManager:
    """Owns the LSL marker and EEG outlets; all methods no-op when pylsl is absent."""

    def __init__(self, stream_name: str = "EEGMarkers"):
        self.outlet = None
        self.eeg_outlet = None
        if LSL_AVAILABLE:
            self._init_marker_stream(stream_name)
            self._init_eeg_stream()

    def _init_marker_stream(self, name: str):
        """Open an irregular-rate, single-channel string marker stream."""
        info = StreamInfo(name, "Markers", 1, 0, cf_string, f"markers-{name}")
        info.desc().append_child_value("manufacturer", "EEG-MI-Backend")
        self.outlet = StreamOutlet(info)
        logger.info(f"LSL marker stream '{name}' opened")

    def _init_eeg_stream(self):
        """Open the EEG outlet and attach per-channel metadata."""
        info = StreamInfo("EEG-MI", "EEG", N_CHANNELS, SAMPLE_RATE, cf_float32, "eeg-mi")
        channels_node = info.desc().append_child("channels")
        for label in CHANNEL_NAMES:
            node = channels_node.append_child("channel")
            node.append_child_value("label", label)
            node.append_child_value("unit", "microvolts")
            node.append_child_value("type", "EEG")
        self.eeg_outlet = StreamOutlet(info, 32, 360)
        logger.info("LSL EEG stream opened")

    def push_marker(self, marker: str):
        """Send one marker string; silently ignored when no outlet exists."""
        if self.outlet is None:
            return
        self.outlet.push_sample([marker])
        logger.debug(f"LSL marker: {marker}")

    def push_eeg_chunk(self, chunk: np.ndarray):
        """Forward a non-empty (n_samples, n_channels) chunk to the EEG outlet."""
        if self.eeg_outlet is not None and chunk.size > 0:
            self.eeg_outlet.push_chunk(chunk.tolist())
|
|
| |
|
|
class BoardManager:
    """Handles a Cyton+Daisy board via BrainFlow, or a built-in simulator.

    A daemon thread appends one (n_channels,) sample vector at a time to a
    ring buffer; readers take thread-safe snapshots through `get_epoch`.
    """

    def __init__(self, config: SessionConfig):
        self.config = config
        self.board = None
        self.board_id = config.board_id
        self.srate = SAMPLE_RATE  # replaced by the board's true rate on connect
        self._running = False
        self._thread: Optional[threading.Thread] = None
        self._buffer = deque(maxlen=BUFFER_SECONDS * SAMPLE_RATE)
        self._lock = threading.Lock()
        self.connected = False

        # Simulator clock state.
        self._sim_t = 0.0
        self._sim_phase = 0.0

    def connect(self) -> bool:
        """Prepare the session and start streaming; True on success."""
        if self.config.simulate:
            logger.info("Board: SIMULATION mode active")
            self.connected = True
            return True

        if not BRAINFLOW_AVAILABLE:
            logger.error("BrainFlow not available; cannot connect to hardware")
            return False

        params = BrainFlowInputParams()
        params.serial_port = self.config.port
        try:
            BoardShim.enable_dev_board_logger()
            self.board = BoardShim(self.board_id, params)
            self.board.prepare_session()
            self.board.start_stream(45000)
            # The hardware rate can differ from the module default (e.g.
            # Cyton+Daisy streams 125 Hz per channel) — size the buffer and all
            # epoch math off the real rate, not the 250 Hz constant.
            self.srate = BoardShim.get_sampling_rate(self.board_id)
            self._buffer = deque(maxlen=BUFFER_SECONDS * self.srate)
            self.connected = True
            logger.info(f"Cyton+Daisy connected: {self.config.port} @ {self.srate} Hz")
            return True
        except BrainFlowError as e:
            logger.error(f"Board connection failed: {e}")
            return False

    def disconnect(self):
        """Stop the acquisition thread and release the BrainFlow session (best-effort)."""
        self.stop_acquisition()
        if self.board and BRAINFLOW_AVAILABLE:
            try:
                self.board.stop_stream()
                self.board.release_session()
            except Exception:
                pass  # board may already be gone; shutdown stays best-effort
        self.connected = False

    def start_acquisition(self):
        """Spawn the daemon thread that fills the ring buffer."""
        self._running = True
        self._thread = threading.Thread(target=self._acquisition_loop, daemon=True)
        self._thread.start()

    def stop_acquisition(self):
        """Signal the acquisition thread to exit and wait briefly for it."""
        self._running = False
        if self._thread is not None:
            self._thread.join(timeout=1.0)
            self._thread = None

    def _acquisition_loop(self):
        """Poll the board (or simulator) ~20x/s and append samples to the buffer."""
        while self._running:
            if self.config.simulate:
                chunk = self._generate_sim_chunk(50)
            else:
                chunk = self._read_board_chunk()

            if chunk is not None and chunk.shape[1] > 0:
                with self._lock:
                    for s in range(chunk.shape[1]):
                        self._buffer.append(chunk[:, s])
            time.sleep(0.05)

    def _read_board_chunk(self) -> Optional[np.ndarray]:
        """Drain the board's internal buffer; returns (n_ch, n_samples) or None."""
        try:
            data = self.board.get_board_data()
            eeg_channels = BoardShim.get_eeg_channels(self.board_id)
            if data.shape[1] == 0:
                return None
            return data[eeg_channels[:N_CHANNELS], :]
        except Exception as e:
            logger.warning(f"Board read error: {e}")
            return None

    def _generate_sim_chunk(self, n_samples: int) -> np.ndarray:
        """Realistic EEG simulation: alpha rhythm + noise + ocular artifacts."""
        chunk = np.zeros((N_CHANNELS, n_samples))
        # endpoint=False keeps consecutive chunks contiguous in time; the old
        # inclusive linspace duplicated the boundary instant between chunks.
        t = np.linspace(self._sim_t, self._sim_t + n_samples / self.srate,
                        n_samples, endpoint=False)
        self._sim_t += n_samples / self.srate

        for ch in range(N_CHANNELS):
            # Alpha amplitude varies per channel for a plausible topography.
            alpha_amp = (0.5 + 0.5 * math.sin(ch * 0.4)) * 15e-6
            alpha = alpha_amp * np.sin(2 * np.pi * 10 * t + ch * 0.3)
            beta = 5e-6 * np.sin(2 * np.pi * 20 * t + ch * 0.6)
            theta = 8e-6 * np.sin(2 * np.pi * 6 * t)
            noise = np.random.randn(n_samples) * 3e-6
            chunk[ch] = alpha + beta + theta + noise

        # Occasional frontal-dominant blink-like artifact.
        if np.random.rand() < 0.02:
            idx = np.random.randint(0, max(1, n_samples - 20))
            pulse = np.hanning(20) * 150e-6
            end = min(idx + 20, n_samples)
            chunk[0, idx:end] += pulse[:end - idx]
            chunk[1, idx:end] += pulse[:end - idx] * 0.5

        return chunk

    def get_epoch(self, duration: float, offset: float = 0.0) -> Optional[np.ndarray]:
        """Latest `duration`-second epoch, excluding the newest `offset` seconds.

        Returns (n_ch, n_samples) or None when the buffer is too short. Uses
        the board's actual sampling rate, not the module-level default.
        """
        n_wanted = int((duration + offset) * self.srate)
        with self._lock:
            buf = list(self._buffer)
        if len(buf) < n_wanted:
            return None
        arr = np.array(buf[-n_wanted:]).T
        # Keep the oldest `duration` seconds of the window, i.e. trim the newest
        # `offset` seconds (e.g. the subject's rating movement). The previous
        # code trimmed the oldest samples instead, which made `offset` a no-op.
        n_keep = int(duration * self.srate)
        return arr[:, :n_keep]

    def get_psd_snapshot(self) -> Optional[np.ndarray]:
        """Latest 2-second window for live display (n_ch, n_samples)."""
        return self.get_epoch(2.0)
|
|
| |
|
|
class SessionManager:
    """Central orchestrator for one recording session.

    Owns the board, signal processing, neurofeedback engine, classifier and
    LSL streams, plus the set of connected WebSocket clients. Code is
    unchanged from the original except for one mojibake character ("β") fixed
    in the artifact log message.
    """

    def __init__(self):
        self.config = SessionConfig()
        self.phase = SessionPhase.IDLE
        self.trials: List[Trial] = []
        self.trial_counter = 0                    # monotonically increasing trial id
        self.current_class: Optional[str] = None  # label of the trial in progress
        self.target_class: Optional[str] = None   # NOTE(review): never assigned in this file

        self.processor = SignalProcessor()
        self.nf_engine = NeurofeedbackEngine()
        self.classifier = EEGClassifier()
        self.lsl = LSLManager()
        self.board = BoardManager(self.config)

        # WebSocket fan-out state.
        self._ws_clients: List[WebSocket] = []
        self._event_queue: asyncio.Queue = asyncio.Queue()  # NOTE(review): not consumed anywhere in this file
        self._current_loop: Optional[asyncio.AbstractEventLoop] = None

    # ---- board lifecycle -------------------------------------------------

    def connect_board(self) -> dict:
        """Connect (or start the simulator), begin buffering, emit session marker."""
        ok = self.board.connect()
        if ok:
            self.board.start_acquisition()
            self.lsl.push_marker("SESSION_START")
        return {"status": "connected" if ok else "error", "simulate": self.config.simulate}

    def disconnect_board(self):
        """Stop acquisition and emit the closing LSL marker."""
        self.board.disconnect()
        self.lsl.push_marker("SESSION_END")

    # ---- trial workflow --------------------------------------------------

    def start_trial(self, class_label: str) -> Trial:
        """Register a new trial, switch to ACQUISITION and push an LSL marker."""
        self.trial_counter += 1
        onset = time.time()
        trial = Trial(
            trial_id=self.trial_counter,
            class_label=class_label,
            onset_time=onset,
        )
        self.trials.append(trial)
        self.current_class = class_label
        self.phase = SessionPhase.ACQUISITION

        marker = f"TRIAL_START;class={class_label};id={self.trial_counter}"
        self.lsl.push_marker(marker)
        logger.info(f"Trial {self.trial_counter} started: {class_label}")
        return trial

    def end_trial(self, quality: int) -> dict:
        """Called after subject rates the trial 1-5.

        Extracts the trial epoch, gates it through artifact rejection, updates
        the neural state / classifier, and returns the neurofeedback result.
        """
        active = self._get_active_trial()
        if not active:
            return {"error": "no active trial"}

        active.quality_score = quality
        self.lsl.push_marker(f"TRIAL_END;id={active.trial_id};quality={quality}")

        # Pull the imagery epoch from the ring buffer (0.5 s guard offset).
        epoch = self.board.get_epoch(EPOCH_DURATION, offset=0.5)
        if epoch is not None:
            clean = self.processor.artifact_rejection(epoch)
            if clean:
                active.eeg_epoch = epoch
                active.features = self.processor.extract_features(epoch)
                logger.info(f"Trial {active.trial_id}: clean epoch extracted, quality={quality}")
            else:
                # "β" mojibake fixed to a plain dash here.
                logger.warning(f"Trial {active.trial_id}: artifact detected - epoch discarded")
                self.lsl.push_marker(f"ARTIFACT;id={active.trial_id}")

        # Rebuild the class state / classifier if enough good trials exist.
        self._maybe_update_model(active.class_label)

        # Neurofeedback similarity of this trial to its own class state.
        nf = 0.0
        if active.features is not None:
            nf = self.nf_engine.compute_similarity(active.features, active.class_label)
            active.nf_score = nf

        return {
            "trial_id": active.trial_id,
            "quality": quality,
            "nf_score": round(nf, 4),
            "feedback": self.nf_engine.get_feedback_audio_params(nf),
            "model_updated": active.class_label in self.nf_engine.states,
        }

    def _get_active_trial(self) -> Optional[Trial]:
        """Most recent trial that has not been rated yet, or None."""
        for t in reversed(self.trials):
            if t.quality_score is None:
                return t
        return None

    def _maybe_update_model(self, class_label: str):
        """Rebuild neural state from high-quality trials of `class_label`."""
        good_features = [
            t.features for t in self.trials
            if t.class_label == class_label
            and t.quality_score is not None
            and t.quality_score >= self.config.min_quality_for_model
            and t.features is not None
        ]
        if len(good_features) >= 3:
            self.nf_engine.update_state(class_label, good_features)
            self._retrain_classifier()

    def _retrain_classifier(self):
        """Fit the LDA on all quality-gated trials once both classes are present."""
        X, y = [], []
        for t in self.trials:
            if t.features is not None and t.quality_score is not None \
                    and t.quality_score >= self.config.min_quality_for_model:
                X.append(t.features)
                y.append(t.class_label)
        if len(set(y)) >= 2:
            self.classifier.fit(np.array(X), y)

    # ---- live data -------------------------------------------------------

    def get_live_signals(self) -> dict:
        """Snapshot for the live display: band powers, raw traces, topomap, quality."""
        epoch = self.board.get_psd_snapshot()
        if epoch is None:
            # Not enough buffered data yet — send a zeroed placeholder frame.
            return {
                "channels": CHANNEL_NAMES,
                "alpha_power": [0.0] * N_CHANNELS,
                "beta_power": [0.0] * N_CHANNELS,
                "theta_power": [0.0] * N_CHANNELS,
                "raw_samples": [[0.0] * 50] * 4,
                "topomap": {n: 0.0 for n in CHANNEL_NAMES},
                "signal_quality": [0.0] * N_CHANNELS,
            }

        try:
            clean = self.processor.preprocess(epoch)
            bp = self.processor.band_power(clean)
            topo = self.processor.compute_topomap(epoch)

            # Quality heuristic: peak-to-peak 10 µV -> 1.0 down to 100 µV -> 0.0.
            peak2peak = (epoch.max(axis=1) - epoch.min(axis=1)) * 1e6
            quality = [float(np.clip(1.0 - (pp - 10) / 90.0, 0.0, 1.0)) for pp in peak2peak]

            # Four representative channels for the raw-trace strip chart.
            sel = [0, 4, 8, 12]
            raw = (clean[sel, -50:] * 1e6).tolist()

            return {
                "channels": CHANNEL_NAMES,
                "alpha_power": [float(v * 1e12) for v in bp["alpha"]],
                "beta_power": [float(v * 1e12) for v in bp["beta"]],
                "theta_power": [float(v * 1e12) for v in bp["theta"]],
                "raw_samples": raw,
                "topomap": {k: float(v * 1e12) for k, v in topo.items()},
                "signal_quality": quality,
            }
        except Exception as e:
            logger.warning(f"Live signal error: {e}")
            return {}

    # ---- state reporting -------------------------------------------------

    def get_state(self) -> dict:
        """Full JSON-serializable snapshot of the session for clients."""
        class_a_trials = [t.to_dict() for t in self.trials if t.class_label == self.config.class_a]
        class_b_trials = [t.to_dict() for t in self.trials if t.class_label == self.config.class_b]

        states = {}
        for label, state in self.nf_engine.states.items():
            states[label] = {
                "n_trials": state.n_trials,
                "convergence_score": round(state.convergence_score, 4),
                "created_at": state.created_at,
            }

        return {
            "phase": self.phase.value,
            "config": {
                "subject_id": self.config.subject_id,
                "session_id": self.config.session_id,
                "class_a": self.config.class_a,
                "class_b": self.config.class_b,
                "simulate": self.config.simulate,
                "min_quality_for_model": self.config.min_quality_for_model,
            },
            "class_a_trials": class_a_trials,
            "class_b_trials": class_b_trials,
            "neural_states": states,
            "classifier_trained": self.classifier.is_trained,
            "board_connected": self.board.connected,
            "total_trials": len(self.trials),
        }

    # ---- websocket fan-out -----------------------------------------------

    def add_ws_client(self, ws: WebSocket):
        """Register a client for broadcasts."""
        self._ws_clients.append(ws)

    def remove_ws_client(self, ws: WebSocket):
        """Unregister a client; safe to call twice."""
        if ws in self._ws_clients:
            self._ws_clients.remove(ws)

    async def broadcast(self, data: dict):
        """Send `data` to every client, dropping any that fail."""
        dead = []
        for ws in self._ws_clients:
            try:
                await ws.send_json(data)
            except Exception:
                dead.append(ws)
        for ws in dead:
            self.remove_ws_client(ws)

    async def live_broadcast_loop(self):
        """Continuously push live EEG data to all connected WS clients (~10 Hz)."""
        while True:
            await asyncio.sleep(0.1)
            if self._ws_clients and self.board.connected:
                live = self.get_live_signals()
                if live:
                    live["type"] = "live_eeg"
                    live["timestamp"] = time.time()
                    await self.broadcast(live)
|
|
| |
|
|
# Resolve the UI file relative to this script so the working directory is irrelevant.
ROOT_DIR = Path(__file__).resolve().parent
SITE_HTML = ROOT_DIR / "site.html"


app = FastAPI(title="EEG Mental Imagery Backend", version="2.0.0")
# Wide-open CORS — the UI may be served from a different origin during development.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


# Single module-level session shared by all HTTP and WebSocket handlers.
session = SessionManager()
|
|
|
|
| @app.get("/") |
| def serve_ui(): |
| """Serve the project UI (site.html) from the same origin as the API / WebSocket.""" |
| if not SITE_HTML.is_file(): |
| return JSONResponse( |
| status_code=404, |
| content={"error": "site.html not found", "path": str(SITE_HTML)}, |
| ) |
| return FileResponse(SITE_HTML, media_type="text/html; charset=utf-8") |
|
|
| @app.on_event("startup") |
| async def startup(): |
| asyncio.create_task(session.live_broadcast_loop()) |
| logger.info("EEG Backend started") |
|
|
| |
|
|
| @app.get("/health") |
| def health(): |
| return {"status": "ok", "brainflow": BRAINFLOW_AVAILABLE, "lsl": LSL_AVAILABLE, |
| "sklearn": SKLEARN_AVAILABLE, "scipy": SCIPY_AVAILABLE} |
|
|
| @app.post("/board/connect") |
| def connect_board(port: str = "/dev/ttyUSB0", simulate: bool = False): |
| session.config.port = port |
| session.config.simulate = simulate or not BRAINFLOW_AVAILABLE |
| return session.connect_board() |
|
|
| @app.post("/board/disconnect") |
| def disconnect_board(): |
| session.disconnect_board() |
| return {"status": "disconnected"} |
|
|
| @app.post("/session/configure") |
| def configure_session( |
| subject_id: str = "S001", |
| class_a: str = "Apple", |
| class_b: str = "House", |
| min_quality: int = 4 |
| ): |
| session.config.subject_id = subject_id |
| session.config.class_a = class_a |
| session.config.class_b = class_b |
| session.config.min_quality_for_model = min_quality |
| session.config.session_id = datetime.now().strftime("%Y%m%d_%H%M%S") |
| return {"status": "configured", "config": asdict(session.config)} |
|
|
| @app.post("/trial/start") |
| def start_trial(class_label: str): |
| trial = session.start_trial(class_label) |
| return trial.to_dict() |
|
|
| @app.post("/trial/end") |
| def end_trial(quality: int): |
| return session.end_trial(quality) |
|
|
| @app.get("/state") |
| def get_state(): |
| return session.get_state() |
|
|
| @app.get("/signals/live") |
| def live_signals(): |
| return session.get_live_signals() |
|
|
| @app.get("/channels/info") |
| def channel_info(): |
| return { |
| "n_channels": N_CHANNELS, |
| "channel_names": CHANNEL_NAMES, |
| "positions": ELECTRODE_POSITIONS_16CH, |
| "sample_rate": SAMPLE_RATE, |
| } |
|
|
| @app.post("/lsl/marker") |
| def push_marker(marker: str): |
| session.lsl.push_marker(marker) |
| return {"pushed": marker, "timestamp": time.time()} |
|
|
| @app.get("/classifier/predict") |
| def predict_current(): |
| epoch = session.board.get_epoch(EPOCH_DURATION) |
| if epoch is None: |
| return {"error": "no data"} |
| feats = session.processor.extract_features(epoch) |
| proba = session.classifier.predict_proba(feats) |
| nf_a = session.nf_engine.compute_similarity(feats, session.config.class_a) |
| nf_b = session.nf_engine.compute_similarity(feats, session.config.class_b) |
| return { |
| "probabilities": proba, |
| "nf_similarity": {session.config.class_a: nf_a, session.config.class_b: nf_b}, |
| "timestamp": time.time(), |
| } |
|
|
| @app.post("/session/reset") |
| def reset_session(): |
| session.trials.clear() |
| session.trial_counter = 0 |
| session.nf_engine.states.clear() |
| session.classifier.is_trained = False |
| session.phase = SessionPhase.IDLE |
| return {"status": "reset"} |
|
|
| |
|
|
| @app.websocket("/ws") |
| async def websocket_endpoint(ws: WebSocket): |
| await ws.accept() |
| session.add_ws_client(ws) |
| logger.info(f"WebSocket client connected. Total: {len(session._ws_clients)}") |
| try: |
| while True: |
| msg = await ws.receive_json() |
| msg_type = msg.get("type", "") |
|
|
| if msg_type == "ping": |
| await ws.send_json({"type": "pong", "timestamp": time.time()}) |
|
|
| elif msg_type == "state": |
| await ws.send_json({"type": "state", **session.get_state()}) |
|
|
| elif msg_type == "start_trial": |
| trial = session.start_trial(msg["class_label"]) |
| await ws.send_json({"type": "trial_started", **trial.to_dict()}) |
|
|
| elif msg_type == "end_trial": |
| result = session.end_trial(msg["quality"]) |
| await ws.send_json({"type": "trial_ended", **result}) |
| |
| await session.broadcast({"type": "state_update", **session.get_state()}) |
|
|
| elif msg_type == "configure": |
| session.config.subject_id = msg.get("subject_id", session.config.subject_id) |
| session.config.class_a = msg.get("class_a", session.config.class_a) |
| session.config.class_b = msg.get("class_b", session.config.class_b) |
| mq = msg.get("min_quality") |
| if mq is not None: |
| session.config.min_quality_for_model = int(mq) |
| await ws.send_json({"type": "configured"}) |
|
|
| elif msg_type == "connect_board": |
| session.config.simulate = msg.get("simulate", True) |
| if msg.get("port"): |
| session.config.port = str(msg["port"]) |
| result = session.connect_board() |
| await ws.send_json({"type": "board_status", **result}) |
|
|
| except WebSocketDisconnect: |
| session.remove_ws_client(ws) |
| logger.info(f"WebSocket client disconnected. Remaining: {len(session._ws_clients)}") |
|
|
| |
|
|
| if __name__ == "__main__": |
| uvicorn.run("backend:app", host="0.0.0.0", port=8765, reload=False, log_level="info") |