| """Minimal FALCON M1 decoder wrapper for CortexFM (reference implementation). |
| |
| This is a self-contained reference that mirrors |
| `src/cortex_fm/eval/falcon_m1_decoder.py` from the source repository, intended |
| to be uploaded alongside the Hugging Face checkpoint so downstream users can |
| evaluate `epoch28-0.2599.ckpt` on the FALCON M1 benchmark without cloning the |
| full project. |
| |
| Dependencies |
| ------------ |
| - torch >= 2.10 |
| - pytorch_lightning |
| - numpy |
| - falcon-challenge == 1.0.2 |
| - cortex_fm (installed from the GitHub companion repository or PyPI mirror) |
| |
| The wrapper implements `falcon_challenge.interface.BCIDecoder` for the M1 |
| task: |
| * input : per-timestep spike counts, shape (batch_size, 64) |
| * output : per-timestep EMG envelope predictions, shape (batch_size, 16) |
| |
| The encoder is a 64-bin (1.28 s) window model. The FALCON evaluator streams |
| data one bin at a time, so this wrapper maintains a rolling 64-bin spike |
| buffer per batch slot and returns the last-bin EMG prediction. |
| |
| For session transitions (`reset(dataset_tags)`), the buffer is zeroed per |
| slot. This honors the "no history leakage across sessions" contract of the |
| FALCON continual-evaluation protocol. |
| |
| Example |
| ------- |
| from huggingface_hub import hf_hub_download |
| from falcon_challenge.evaluator import FalconEvaluator |
| from falcon_challenge.config import FalconConfig, FalconTask |
| |
| ckpt = hf_hub_download( |
| repo_id="<USERNAME>/CortexFM", |
| filename="epoch28-0.2599.ckpt", |
| ) |
| decoder = CortexFMFalconDecoder(ckpt_path=ckpt, batch_size=4) |
| evaluator = FalconEvaluator( |
| eval_remote=False, |
| split="m1", |
| verbose=True, |
| ) |
| metrics = evaluator.evaluate(decoder, phase="minival") |
| print(metrics) |
| |
| See the upstream `cortex_fm.eval.run_falcon_m1` script in the project's |
| GitHub repository for a complete CLI harness including auxiliary co-bps |
| computation. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import List |
|
|
| import numpy as np |
| import torch |
| from torch import Tensor |
|
|
| from falcon_challenge.config import FalconConfig, FalconTask |
| from falcon_challenge.interface import BCIDecoder |
|
|
| from cortex_fm.training import CortexFMPretrainModule |
|
|
|
|
class CortexFMFalconDecoder(BCIDecoder):
    """FALCON M1 decoder backed by a pretrained CortexFM checkpoint.

    The evaluator hands `predict` one bin per call. This wrapper keeps the
    most recent `context_T` bins for every batch slot in a rolling buffer,
    runs the model over the whole window on each call, and returns the
    prediction for the newest bin only.

    Parameters
    ----------
    ckpt_path : str | Path
        Lightning checkpoint produced by `cortex_fm.run_pretrain`. For the
        public release this is the file `epoch28-0.2599.ckpt` downloaded from
        Hugging Face Hub.
    task_config : FalconConfig | None
        Defaults to `FalconConfig(task=FalconTask.m1, bin_size_ms=20)`.
    batch_size : int
        Must match the evaluator batch size (M1 recommended: 4).
    device : str
        "cuda" if available else "cpu". The default is evaluated once, when
        the class body is executed.
    context_T : int
        Rolling window length in 20 ms bins. Should match pretraining context
        (64, i.e. 1.28 s). Must be at least 1.

    Raises
    ------
    ValueError
        If `task_config` is not an M1 config, or `context_T` is below 1.
    FileNotFoundError
        If `ckpt_path` does not point to an existing file.
    """

    def __init__(
        self,
        ckpt_path: str | Path,
        task_config: FalconConfig | None = None,
        batch_size: int = 4,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        context_T: int = 64,
    ):
        if task_config is None:
            task_config = FalconConfig(task=FalconTask.m1, bin_size_ms=20)
        if task_config.task != FalconTask.m1:
            raise ValueError(
                f"CortexFMFalconDecoder targets FalconTask.m1; got {task_config.task}"
            )

        # Fail fast: an empty window would only surface later as an indexing
        # error inside `predict`, after the (slow) checkpoint load.
        context_T = int(context_T)
        if context_T < 1:
            raise ValueError(f"context_T must be >= 1; got {context_T}")

        super().__init__(task_config=task_config, batch_size=batch_size)

        self.ckpt_path = Path(ckpt_path).expanduser().resolve()
        if not self.ckpt_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {self.ckpt_path}")

        self.device = torch.device(device)
        self.context_T = context_T

        # Load the pretrained module and freeze it for inference only.
        self.module = CortexFMPretrainModule.load_from_checkpoint(
            str(self.ckpt_path),
            map_location=self.device,
            strict=True,
        )
        self.module.to(self.device)
        self.module.eval()
        for param in self.module.parameters():
            param.requires_grad_(False)

        # Input / output widths come from the task configuration.
        self.n_channels: int = int(task_config.n_channels)
        self.out_dim: int = int(task_config.out_dim)

        # Rolling state, (re)created by `_allocate_buffers`:
        #   _buffer      -- (B, context_T, n_channels) spike window per slot
        #   _buffer_fill -- (B,) count of bins written so far, capped at
        #                   context_T; maintained here but not consulted by
        #                   the forward pass in this class
        self._buffer: Tensor | None = None
        self._buffer_fill: Tensor | None = None
        self._tags: List[str] = []

        self._allocate_buffers(self.batch_size)

    def reset(self, dataset_tags: List[str] | None = None) -> None:
        """Reset rolling buffers (called when at least one slot switches file).

        Parameters
        ----------
        dataset_tags : list[str] | None
            One tag per batch slot. `None` is treated as `[""]`. If the
            number of tags differs from the current batch size, the decoder
            is resized to match before the buffers are cleared.
        """
        # `None` replaces a mutable `[""]` default; the effective value is
        # unchanged for callers that omit the argument.
        tags = [""] if dataset_tags is None else list(dataset_tags)
        if len(tags) != self.batch_size:
            self.set_batch_size(len(tags))
            self._allocate_buffers(self.batch_size)
        self._tags = tags
        # Every slot is cleared, so no history carries over from the
        # previous files.
        self._buffer.zero_()
        self._buffer_fill.zero_()

    def predict(self, neural_observations: np.ndarray) -> np.ndarray:
        """One evaluator step.

        Parameters
        ----------
        neural_observations : np.ndarray
            Shape (batch_size, n_channels). Binned spike counts for one 20 ms
            bin per batch slot.

        Returns
        -------
        np.ndarray
            Shape (batch_size, out_dim), float32. Predicted EMG envelope for
            this bin (out_dim is 16 for M1).

        Raises
        ------
        ValueError
            If the input is not 2-D or its channel count does not match the
            task configuration.
        """
        obs = np.asarray(neural_observations)
        if obs.ndim != 2:
            raise ValueError(
                f"Expected 2D (B, n_channels) input, got shape {obs.shape}"
            )
        B, C = obs.shape
        # Check the channel count BEFORE any resize so that a rejected call
        # never wipes the rolling history as a side effect.
        if C != self.n_channels:
            raise ValueError(
                f"n_channels mismatch: got {C}, expected {self.n_channels}"
            )
        if B != self.batch_size:
            # A different batch size invalidates the per-slot buffers; they
            # are re-created empty.
            self.set_batch_size(B)
            self._allocate_buffers(B)

        # Shift the window one bin towards the past. The clone is required
        # because source and destination slices overlap in memory.
        self._buffer[:, :-1, :] = self._buffer[:, 1:, :].clone()
        # `ascontiguousarray` lets `from_numpy` accept sliced / reversed
        # views (it rejects negative strides otherwise).
        newest = torch.from_numpy(np.ascontiguousarray(obs))
        self._buffer[:, -1, :] = newest.to(
            self._buffer.device, dtype=self._buffer.dtype
        )
        self._buffer_fill.add_(1).clamp_(max=self.context_T)

        # `_forward_last_bin` already runs under `torch.no_grad()`.
        emg_last = self._forward_last_bin(self._buffer)
        return emg_last.detach().cpu().numpy().astype(np.float32)

    def on_done(self, dones: np.ndarray) -> None:
        """No per-trial state in continual M1 mode; buffer spans whole session."""

    def _allocate_buffers(self, B: int) -> None:
        """Create zero-filled rolling state for `B` batch slots on `self.device`."""
        self._buffer = torch.zeros(
            B, self.context_T, self.n_channels,
            dtype=torch.int64, device=self.device,
        )
        self._buffer_fill = torch.zeros(B, dtype=torch.int64, device=self.device)

    @torch.no_grad()
    def _forward_last_bin(self, spike_buf: Tensor) -> Tensor:
        """Full-context joint forward, return the EMG prediction at the last bin.

        Parameters
        ----------
        spike_buf : Tensor (B, T, N) int

        Returns
        -------
        Tensor
            float32, shape (B, M) where M is read from the module output
            (presumably the EMG channel count -- confirm against the module).
        """
        B, T, _ = spike_buf.shape

        # The module takes an EMG tensor alongside the spikes; at inference
        # an all-zero placeholder of the task's output width is supplied.
        emg_in = torch.zeros(
            B, T, self.out_dim, device=self.device, dtype=torch.float32
        )

        out = self.module(spike_buf, emg_in)

        M = out["M"]
        # `reshape` instead of `view`: identical when the prediction is
        # contiguous, and does not raise when it is not.
        last = out["emg_pred"].reshape(B, T, M)[:, -1, :]
        return last.float()
|
|