"""Minimal FALCON M1 decoder wrapper for CortexFM (reference implementation). This is a self-contained reference that mirrors `src/cortex_fm/eval/falcon_m1_decoder.py` from the source repository, intended to be uploaded alongside the Hugging Face checkpoint so downstream users can evaluate `epoch28-0.2599.ckpt` on the FALCON M1 benchmark without cloning the full project. Dependencies ------------ - torch >= 2.10 - pytorch_lightning - numpy - falcon-challenge == 1.0.2 - cortex_fm (installed from the GitHub companion repository or PyPI mirror) The wrapper implements `falcon_challenge.interface.BCIDecoder` for the M1 task: * input : per-timestep spike counts, shape (batch_size, 64) * output : per-timestep EMG envelope predictions, shape (batch_size, 16) The encoder is a 64-bin (1.28 s) window model. The FALCON evaluator streams data one bin at a time, so this wrapper maintains a rolling 64-bin spike buffer per batch slot and returns the last-bin EMG prediction. For session transitions (`reset(dataset_tags)`), the buffer is zeroed per slot. This honors the "no history leakage across sessions" contract of the FALCON continual-evaluation protocol. Example ------- from huggingface_hub import hf_hub_download from falcon_challenge.evaluator import FalconEvaluator from falcon_challenge.config import FalconConfig, FalconTask ckpt = hf_hub_download( repo_id="/CortexFM", filename="epoch28-0.2599.ckpt", ) decoder = CortexFMFalconDecoder(ckpt_path=ckpt, batch_size=4) evaluator = FalconEvaluator( eval_remote=False, split="m1", verbose=True, ) metrics = evaluator.evaluate(decoder, phase="minival") print(metrics) See the upstream `cortex_fm.eval.run_falcon_m1` script in the project's GitHub repository for a complete CLI harness including auxiliary co-bps computation. """ from __future__ import annotations from pathlib import Path from typing import List import numpy as np import torch from torch import Tensor from falcon_challenge.config import FalconConfig, FalconTask from falcon_challenge.interface import BCIDecoder from cortex_fm.training import CortexFMPretrainModule class CortexFMFalconDecoder(BCIDecoder): """FALCON M1 decoder backed by a pretrained CortexFM checkpoint. Parameters ---------- ckpt_path : str | Path Lightning checkpoint produced by `cortex_fm.run_pretrain`. For the public release this is the file `epoch28-0.2599.ckpt` downloaded from Hugging Face Hub. task_config : FalconConfig | None Defaults to `FalconConfig(task=FalconTask.m1, bin_size_ms=20)`. batch_size : int Must match the evaluator batch size (M1 recommended: 4). device : str "cuda" if available else "cpu". context_T : int Rolling window length in 20 ms bins. Should match pretraining context (64, i.e. 1.28 s). """ def __init__( self, ckpt_path: str | Path, task_config: FalconConfig | None = None, batch_size: int = 4, device: str = "cuda" if torch.cuda.is_available() else "cpu", context_T: int = 64, ): if task_config is None: task_config = FalconConfig(task=FalconTask.m1, bin_size_ms=20) if task_config.task != FalconTask.m1: raise ValueError( f"CortexFMFalconDecoder targets FalconTask.m1; got {task_config.task}" ) super().__init__(task_config=task_config, batch_size=batch_size) self.ckpt_path = Path(ckpt_path).expanduser().resolve() if not self.ckpt_path.exists(): raise FileNotFoundError(f"Checkpoint not found: {self.ckpt_path}") self.device = torch.device(device) self.context_T = int(context_T) # Load Lightning module (strict so we catch silent state_dict drift) self.module = CortexFMPretrainModule.load_from_checkpoint( str(self.ckpt_path), map_location=self.device, strict=True, ) self.module.to(self.device) self.module.eval() for p in self.module.parameters(): p.requires_grad_(False) # Task dimensions self.n_channels: int = int(task_config.n_channels) # 64 for M1 self.out_dim: int = int(task_config.out_dim) # 16 for M1 # Rolling spike buffers (batch_size, context_T, n_channels), int self._buffer: Tensor | None = None self._buffer_fill: Tensor | None = None self._tags: List[str] = [] self._allocate_buffers(self.batch_size) # ---------------- public API (FALCON) ---------------- def reset(self, dataset_tags: List[str] = [""]) -> None: """Reset rolling buffers (called when at least one slot switches file).""" if len(dataset_tags) != self.batch_size: self.set_batch_size(len(dataset_tags)) self._allocate_buffers(self.batch_size) self._tags = list(dataset_tags) self._buffer.zero_() self._buffer_fill.zero_() def predict(self, neural_observations: np.ndarray) -> np.ndarray: """One evaluator step. Parameters ---------- neural_observations : np.ndarray Shape (batch_size, n_channels). Binned spike counts for one 20 ms bin per batch slot. Returns ------- np.ndarray Shape (batch_size, 16). Predicted EMG envelope for this bin. """ if neural_observations.ndim != 2: raise ValueError( f"Expected 2D (B, n_channels) input, got shape {neural_observations.shape}" ) B, C = neural_observations.shape if B != self.batch_size: self.set_batch_size(B) self._allocate_buffers(B) if C != self.n_channels: raise ValueError( f"n_channels mismatch: got {C}, expected {self.n_channels}" ) # Roll the buffer: shift left by 1, append new observation at last bin. self._buffer[:, :-1, :] = self._buffer[:, 1:, :].clone() self._buffer[:, -1, :] = torch.from_numpy(neural_observations).to( self._buffer.device, dtype=self._buffer.dtype ) self._buffer_fill = torch.clamp(self._buffer_fill + 1, max=self.context_T) # Run model; EMG envelope prediction of last bin per slot. with torch.no_grad(): emg_last = self._forward_last_bin(self._buffer) # (B, 16) float32 return emg_last.detach().cpu().numpy().astype(np.float32) def on_done(self, dones: np.ndarray) -> None: """No per-trial state in continual M1 mode; buffer spans whole session.""" # FALCON M1 evaluator runs in continual mode; cross-trial neural # history is preserved within a session, only session boundaries # zero the buffer (handled by `reset`). return None # ---------------- internals ---------------- def _allocate_buffers(self, B: int) -> None: self._buffer = torch.zeros( B, self.context_T, self.n_channels, dtype=torch.int64, device=self.device, ) self._buffer_fill = torch.zeros(B, dtype=torch.int64, device=self.device) @torch.no_grad() def _forward_last_bin(self, spike_buf: Tensor) -> Tensor: """Full-context joint forward, return (B, 16) EMG at the last bin. Parameters ---------- spike_buf : Tensor (B, T, N) int """ B, T, N = spike_buf.shape # EMG envelope placeholder (zeros) for pretrained joint encoder. # At inference the EMG head is purely generative; real EMG is unseen. emg_in = torch.zeros(B, T, 16, device=self.device, dtype=torch.float32) out = self.module(spike_buf, emg_in) # emg_pred shape: (B, T*M), muscle-major at each t: # index (t, i) -> t * M + i # We want (B, M) at t = T - 1: M = out["M"] last = out["emg_pred"].view(B, T, M)[:, -1, :] # (B, 16) return last.float()