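# CortexFM / benchmark_wrapper / falcon_m1_decoder_minimal.py
# (Hugging Face file-viewer header captured together with the source; it is
# not part of the module and is kept here only as a comment.)
# newempire1101's picture
# Add benchmark_wrapper/falcon_m1_decoder_minimal.py
# 0b3aba5 verified
# Raw
# History Blame Contribute Delete
# 8.1 kB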
"""Minimal FALCON M1 decoder wrapper for CortexFM (reference implementation).
This is a self-contained reference that mirrors
`src/cortex_fm/eval/falcon_m1_decoder.py` from the source repository, intended
to be uploaded alongside the Hugging Face checkpoint so downstream users can
evaluate `epoch28-0.2599.ckpt` on the FALCON M1 benchmark without cloning the
full project.
Dependencies
------------
- torch >= 2.10
- pytorch_lightning
- numpy
- falcon-challenge == 1.0.2
- cortex_fm (installed from the GitHub companion repository or PyPI mirror)
The wrapper implements `falcon_challenge.interface.BCIDecoder` for the M1
task:
* input : per-timestep spike counts, shape (batch_size, 64)
* output : per-timestep EMG envelope predictions, shape (batch_size, 16)
The encoder is a 64-bin (1.28 s) window model. The FALCON evaluator streams
data one bin at a time, so this wrapper maintains a rolling 64-bin spike
buffer per batch slot and returns the last-bin EMG prediction.
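On session transitions the evaluator calls `reset(dataset_tags)`, which zeroes
the rolling buffer of every batch slot (and re-allocates the buffers when the
number of tags differs from the current batch size). This honors the "no
history leakage across sessions" contract of the FALCON continual-evaluation
protocol.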
Example
-------
    from huggingface_hub import hf_hub_download
    from falcon_challenge.evaluator import FalconEvaluator
    from falcon_challenge.config import FalconConfig, FalconTask

    # CortexFMFalconDecoder is the class defined in this module.
    ckpt = hf_hub_download(
        repo_id="<USERNAME>/CortexFM",
        filename="epoch28-0.2599.ckpt",
    )
    decoder = CortexFMFalconDecoder(ckpt_path=ckpt, batch_size=4)
    evaluator = FalconEvaluator(
        eval_remote=False,
        split="m1",
        verbose=True,
    )
    metrics = evaluator.evaluate(decoder, phase="minival")
    print(metrics)

See the upstream `cortex_fm.eval.run_falcon_m1` script in the project's
GitHub repository for a complete CLI harness including auxiliary co-bps
computation.
"""
from __future__ import annotations
from pathlib import Path
from typing import List
import numpy as np
import torch
from torch import Tensor
from falcon_challenge.config import FalconConfig, FalconTask
from falcon_challenge.interface import BCIDecoder
from cortex_fm.training import CortexFMPretrainModule
class CortexFMFalconDecoder(BCIDecoder):
    """FALCON M1 decoder backed by a pretrained CortexFM checkpoint.

    The decoder keeps a rolling window of the most recent ``context_T`` spike
    bins for every batch slot. Each `predict` call shifts the window by one
    bin, appends the new observation, runs the frozen model on the whole
    window and returns the prediction for the newest bin. Until ``context_T``
    bins have been observed the older part of the window is simply zeros.

    Parameters
    ----------
    ckpt_path : str | Path
        Lightning checkpoint produced by `cortex_fm.run_pretrain`. For the
        public release this is the file `epoch28-0.2599.ckpt` downloaded from
        Hugging Face Hub.
    task_config : FalconConfig | None
        Defaults to `FalconConfig(task=FalconTask.m1, bin_size_ms=20)`.
    batch_size : int
        Must match the evaluator batch size (M1 recommended: 4). If a later
        `reset` / `predict` call arrives with a different batch dimension the
        buffers are re-allocated (and therefore cleared) to match.
    device : str
        "cuda" if available else "cpu".
    context_T : int
        Rolling window length in 20 ms bins. Should match pretraining context
        (64, i.e. 1.28 s). Must be >= 1.

    Raises
    ------
    ValueError
        If ``task_config`` is not an M1 config or ``context_T`` is < 1.
    FileNotFoundError
        If ``ckpt_path`` does not exist.
    """

    def __init__(
        self,
        ckpt_path: str | Path,
        task_config: FalconConfig | None = None,
        batch_size: int = 4,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        context_T: int = 64,
    ):
        if task_config is None:
            task_config = FalconConfig(task=FalconTask.m1, bin_size_ms=20)
        if task_config.task != FalconTask.m1:
            raise ValueError(
                f"CortexFMFalconDecoder targets FalconTask.m1; got {task_config.task}"
            )
        # A window of length 0 (or less) cannot hold the newest bin; fail at
        # construction time rather than on the first `predict` call.
        if int(context_T) < 1:
            raise ValueError(
                f"context_T must be a positive integer; got {context_T}"
            )
        super().__init__(task_config=task_config, batch_size=batch_size)

        self.ckpt_path = Path(ckpt_path).expanduser().resolve()
        if not self.ckpt_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {self.ckpt_path}")

        self.device = torch.device(device)
        self.context_T = int(context_T)

        # Load Lightning module (strict so we catch silent state_dict drift).
        self.module = CortexFMPretrainModule.load_from_checkpoint(
            str(self.ckpt_path),
            map_location=self.device,
            strict=True,
        )
        self.module.to(self.device)
        self.module.eval()
        # Inference only: freeze every parameter.
        for p in self.module.parameters():
            p.requires_grad_(False)

        # Task dimensions, taken from the FALCON task config.
        self.n_channels: int = int(task_config.n_channels)  # 64 for M1
        self.out_dim: int = int(task_config.out_dim)  # 16 for M1

        # Rolling spike buffers (batch_size, context_T, n_channels), int64,
        # plus a per-slot counter of how many bins have been written so far
        # (saturating at context_T). The counter is bookkeeping only; the
        # forward pass does not read it.
        self._buffer: Tensor | None = None
        self._buffer_fill: Tensor | None = None
        self._tags: List[str] = []
        self._allocate_buffers(self.batch_size)

    # ---------------- public API (FALCON) ----------------

    def reset(self, dataset_tags: List[str] | None = None) -> None:
        """Clear the rolling history of every batch slot.

        Parameters
        ----------
        dataset_tags : list of str, optional
            One tag per batch slot. ``None`` (the default) is treated as
            ``[""]``, i.e. a single-slot batch. If the number of tags differs
            from the current batch size, the batch size is updated and the
            buffers are re-allocated to match.
        """
        # `None` sentinel instead of a shared mutable default list.
        tags = [""] if dataset_tags is None else list(dataset_tags)
        if len(tags) != self.batch_size:
            self.set_batch_size(len(tags))
            # Freshly allocated buffers are already all-zero.
            self._allocate_buffers(self.batch_size)
        else:
            self._buffer.zero_()
            self._buffer_fill.zero_()
        self._tags = tags

    def predict(self, neural_observations: np.ndarray) -> np.ndarray:
        """One evaluator step.

        Parameters
        ----------
        neural_observations : np.ndarray
            Shape (batch_size, n_channels). Binned spike counts for one 20 ms
            bin per batch slot. Values are cast to the buffer's integer dtype.

        Returns
        -------
        np.ndarray
            float32 array with one row per batch slot holding the model's
            prediction for this bin (16 EMG channels for M1).

        Raises
        ------
        ValueError
            If the input is not 2-D or its channel count differs from
            ``self.n_channels``. Both checks run before any internal state is
            touched, so a rejected call leaves the history intact.
        """
        if neural_observations.ndim != 2:
            raise ValueError(
                f"Expected 2D (B, n_channels) input, got shape {neural_observations.shape}"
            )
        B, C = neural_observations.shape
        if C != self.n_channels:
            raise ValueError(
                f"n_channels mismatch: got {C}, expected {self.n_channels}"
            )
        if B != self.batch_size:
            # Batch dimension changed: follow it. This drops all history,
            # because the buffers are re-created from scratch.
            self.set_batch_size(B)
            self._allocate_buffers(B)

        # Roll the buffer: shift left by 1, append new observation at last bin.
        # The clone is required because source and destination overlap.
        self._buffer[:, :-1, :] = self._buffer[:, 1:, :].clone()
        self._buffer[:, -1, :] = torch.from_numpy(neural_observations).to(
            self._buffer.device, dtype=self._buffer.dtype
        )
        self._buffer_fill = torch.clamp(self._buffer_fill + 1, max=self.context_T)

        # Run model; `_forward_last_bin` already disables autograd.
        emg_last = self._forward_last_bin(self._buffer)
        return emg_last.detach().cpu().numpy().astype(np.float32)

    def on_done(self, dones: np.ndarray) -> None:
        """No-op: trial boundaries do not clear any state.

        Per the original author's note, the FALCON M1 evaluator runs in
        continual mode, so neural history is kept across trials within a
        session; only `reset` clears the buffers.
        """
        return None

    # ---------------- internals ----------------

    def _allocate_buffers(self, B: int) -> None:
        """(Re)create zero-filled rolling buffers for ``B`` batch slots."""
        self._buffer = torch.zeros(
            B, self.context_T, self.n_channels,
            dtype=torch.int64, device=self.device,
        )
        self._buffer_fill = torch.zeros(B, dtype=torch.int64, device=self.device)

    @torch.no_grad()
    def _forward_last_bin(self, spike_buf: Tensor) -> Tensor:
        """Run the model on the full window and return the last-bin output.

        Parameters
        ----------
        spike_buf : Tensor (B, T, N) int
            Rolling spike window for every batch slot.

        Returns
        -------
        Tensor
            float32 tensor of shape (B, M), where ``M`` is reported by the
            model in ``out["M"]`` (16 for M1).
        """
        B, T, _ = spike_buf.shape
        # All-zero EMG placeholder: the model takes an EMG input alongside the
        # spikes, but real EMG is never available at inference time. Its width
        # follows the task config instead of a hard-coded 16.
        emg_in = torch.zeros(
            B, T, self.out_dim, device=self.device, dtype=torch.float32
        )
        out = self.module(spike_buf, emg_in)
        # Per the original author's note, emg_pred is (B, T*M), muscle-major
        # at each t: index (t, i) -> t * M + i. We want (B, M) at t = T - 1.
        M = out["M"]
        # reshape (not view) so a non-contiguous model output is handled too.
        last = out["emg_pred"].reshape(B, T, M)[:, -1, :]
        return last.float()