Spaces:
Sleeping
Sleeping
| """Braindecode Model Explorer — interactive architecture browser. | |
| This Hugging Face Space lets users browse all 57 EEG model architectures | |
| in braindecode, read the rendered docstring (parameters, references, | |
| architecture figure), and instantiate any model with custom signal | |
| shape to inspect its parameter count and layer summary. | |
| No pretrained weights are loaded — this is a pure architecture explorer. | |
| Run locally: | |
| pip install -r requirements.txt | |
| python app.py | |
| """ | |
| from __future__ import annotations | |
| import inspect | |
| from typing import Any | |
| import gradio as gr | |
| import torch | |
| from torchinfo import summary | |
| import braindecode.models as M | |
| from braindecode.models.base import EEGModuleMixin | |
| from docstring_renderer import ( | |
| get_signature_str, | |
| get_source_link, | |
| render_docstring_html, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Catalog: discover every EEGModuleMixin subclass exported by braindecode. | |
| # --------------------------------------------------------------------------- | |
def _discover_models() -> dict[str, type]:
    """Collect the EEG model classes exported by ``braindecode.models``.

    Candidate names come from ``braindecode.models.__all__``; when that is
    missing or empty, ``dir()`` of the module is used instead. Names are
    visited in sorted order. Underscore-prefixed names are skipped, as is
    anything that is not a class deriving from ``EEGModuleMixin`` (the mixin
    itself is excluded).

    Returns
    -------
    dict[str, type]
        Mapping from exported class name to the class object.
    """
    exported = getattr(M, "__all__", []) or dir(M)
    public_names = (label for label in sorted(exported) if not label.startswith("_"))
    # getattr with a default: a name listed in __all__ but missing on the
    # module yields None and is filtered out by the isclass check below.
    pairs = ((label, getattr(M, label, None)) for label in public_names)
    return {
        label: candidate
        for label, candidate in pairs
        if inspect.isclass(candidate)
        and issubclass(candidate, EEGModuleMixin)
        and candidate is not EEGModuleMixin
    }
# Every discoverable model class, keyed by its exported class name.
MODELS: dict[str, type] = _discover_models()
# Alphabetical list of model names, used as the dropdown choices.
MODEL_NAMES: list[str] = sorted(MODELS)
| # --------------------------------------------------------------------------- | |
| # Heuristic defaults for the signal-shape form. Different model families | |
| # expect very different inputs (sleep stagers want 30 s @ 100 Hz; motor- | |
| # imagery models want ~4 s @ 250 Hz). Pick a reasonable default per family. | |
| # --------------------------------------------------------------------------- | |
| DEFAULTS = { | |
| "sleep": dict(n_chans=2, sfreq=100, input_window_seconds=30.0, n_outputs=5), | |
| "biot": dict(n_chans=16, sfreq=200, input_window_seconds=10.0, n_outputs=2), | |
| "bendr": dict(n_chans=20, sfreq=256, input_window_seconds=4.0, n_outputs=2), | |
| "labram": dict(n_chans=22, sfreq=200, input_window_seconds=4.0, n_outputs=2), | |
| "default": dict(n_chans=22, sfreq=250, input_window_seconds=4.0, n_outputs=4), | |
| } | |
| def _defaults_for(name: str) -> dict[str, Any]: | |
| lower = name.lower() | |
| if "sleep" in lower or name in {"USleep", "AttnSleep", "DeepSleepNet"}: | |
| return DEFAULTS["sleep"] | |
| if "biot" in lower: | |
| return DEFAULTS["biot"] | |
| if "bendr" in lower: | |
| return DEFAULTS["bendr"] | |
| if "labram" in lower or "cbramod" in lower or "eegpt" in lower: | |
| return DEFAULTS["labram"] | |
| return DEFAULTS["default"] | |
| # --------------------------------------------------------------------------- | |
| # Rendering helpers | |
| # --------------------------------------------------------------------------- | |
def _info_card(name: str) -> str:
    """Return the HTML header card for model ``name``.

    The card shows the class name, the string from ``get_signature_str``
    and, when ``get_source_link`` returns a truthy value, a "View source on
    GitHub" link.

    Raises
    ------
    KeyError
        If ``name`` is not in ``MODELS``.
    """
    import html  # only this helper needs it

    cls = MODELS[name]
    # Escape before embedding in HTML. This assumes get_signature_str returns
    # plain text rather than HTML (confirm). A default rendered as
    # "<class '...'>" would otherwise be parsed by the browser as a tag.
    sig = html.escape(str(get_signature_str(cls)))
    link = get_source_link(cls)
    link_html = (
        f'<a href="{html.escape(link, quote=True)}" target="_blank" '
        f'style="color:#0072B2;text-decoration:none;">View source on GitHub →</a>'
        if link
        else ""
    )
    return (
        f"<div style='background:#f6f8fa;padding:12px 16px;border-radius:8px;"
        f"border:1px solid #d0d7de;margin-bottom:12px;'>"
        f"<div style='font-size:1.3em;font-weight:600;color:#0072B2;"
        f"margin-bottom:4px;'>{name}</div>"
        f"<div style='font-family:Menlo,Consolas,monospace;font-size:0.82em;"
        f"color:#475569;margin-bottom:6px;word-break:break-all;'>{sig}</div>"
        f"<div style='font-size:0.9em;'>{link_html}</div>"
        f"</div>"
    )
def show_model(name: str) -> tuple[str, str, dict, dict, dict, dict]:
    """Update info card, rendered docstring, and form defaults.

    Returns a 6-tuple matching the dropdown's ``change`` outputs: info-card
    HTML, docstring HTML, then updates for ``n_chans``, ``sfreq``,
    ``input_window_seconds`` and ``n_outputs``.
    """
    if name not in MODELS:
        # E.g. the dropdown was cleared (value None). Clear the two HTML panes
        # but leave the numeric fields untouched. An argument-less gr.update()
        # is a no-op update; a bare {} is not an update marker and could be
        # taken as the new field value. Build one fresh update per output.
        return "", "", gr.update(), gr.update(), gr.update(), gr.update()
    info = _info_card(name)
    doc_html = render_docstring_html(MODELS[name].__doc__)
    d = _defaults_for(name)
    return (
        info,
        doc_html,
        gr.update(value=d["n_chans"]),
        gr.update(value=d["sfreq"]),
        gr.update(value=d["input_window_seconds"]),
        gr.update(value=d["n_outputs"]),
    )
def _check_signal_config(
    n_chans: int | None,
    sfreq: float | None,
    window_s: float | None,
    n_outputs: int | None,
) -> str | None:
    """Return a user-facing description of what is wrong, or None if usable."""
    # gr.Number can deliver None when the user clears a field.
    if any(v is None for v in (n_chans, sfreq, window_s, n_outputs)):
        return "Fill in every signal-configuration field."
    if n_chans < 1 or n_outputs < 1:
        return "`n_chans` and `n_outputs` must be at least 1."
    if sfreq <= 0 or window_s <= 0:
        return "`sfreq` and `input_window_seconds` must be positive."
    if int(round(window_s * sfreq)) < 1:
        return (
            "The window is shorter than one sample; increase "
            "`input_window_seconds` or `sfreq`."
        )
    return None


def _model_n_times(model: torch.nn.Module, fallback: int) -> int:
    """Return the model's own positive integer ``n_times``, else ``fallback``."""
    try:
        value = getattr(model, "n_times", None)
    except Exception:  # noqa: BLE001 — may be a property that raises if it cannot infer
        return fallback
    if isinstance(value, int) and value > 0:
        return value
    return fallback


def instantiate(
    name: str,
    n_chans: int,
    sfreq: float,
    window_s: float,
    n_outputs: int,
) -> tuple[str, str]:
    """Instantiate the selected model and run a dummy forward pass.

    Returns
    -------
    tuple[str, str]
        Markdown with a metrics table (or an error message), and a fenced
        block with the torchinfo layer summary ("" when the model could not
        be built).
    """
    if name not in MODELS:
        return "Pick a model first.", ""
    problem = _check_signal_config(n_chans, sfreq, window_s, n_outputs)
    if problem is not None:
        return f"❌ {problem}", ""
    cls = MODELS[name]
    n_times = int(round(window_s * sfreq))
    kwargs = dict(
        n_chans=int(n_chans),
        sfreq=float(sfreq),
        input_window_seconds=float(window_s),
        n_outputs=int(n_outputs),
    )
    # Drop kwargs the class does not accept (some models do not take all
    # of these in __init__; the mixin infers what it can).
    sig_params = set(inspect.signature(cls.__init__).parameters)
    kwargs = {k: v for k, v in kwargs.items() if k in sig_params}
    try:
        model = cls(**kwargs)
    except Exception as exc:  # noqa: BLE001 — surface any constructor error
        return f"❌ **Failed to instantiate `{name}`** with `{kwargs}`:\n```\n{exc}\n```", ""
    # Inference mode: deterministic dropout, and BatchNorm uses running
    # statistics, so the batch-of-1 torchinfo pass cannot trip batch-stat checks.
    model.eval()
    # The model may derive n_times with different rounding than the
    # int(round(...)) above; feed it the length it reports it was built for.
    n_times = _model_n_times(model, n_times)
    n_params = sum(p.numel() for p in model.parameters())
    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    try:
        info = summary(
            model,
            input_size=(1, int(n_chans), n_times),
            depth=3,
            verbose=0,
            col_names=("output_size", "num_params"),
        )
        summary_str = str(info)
    except Exception as exc:  # noqa: BLE001
        summary_str = f"(torchinfo summary unavailable: {exc})"
    try:
        x = torch.randn(2, int(n_chans), n_times)
        with torch.no_grad():
            y = model(x)
        out_shape = tuple(y.shape) if hasattr(y, "shape") else type(y).__name__
    except Exception as exc:  # noqa: BLE001
        out_shape = f"forward failed: {exc}"
    header = (
        f"### `{name}` instantiated\n\n"
        f"| metric | value |\n|---|---|\n"
        f"| total parameters | **{n_params:,}** |\n"
        f"| trainable parameters | {n_train:,} |\n"
        f"| input shape | `(batch, {n_chans}, {n_times})` |\n"
        f"| output shape | `{out_shape}` |\n"
        f"| input window | {window_s} s @ {sfreq} Hz |\n"
    )
    return header, f"```\n{summary_str}\n```"
| # --------------------------------------------------------------------------- | |
| # UI | |
| # --------------------------------------------------------------------------- | |
# Landing-page Markdown. The architecture count is computed from the
# discovered catalog so it cannot drift from what the dropdown offers.
INTRO = f"""
# Braindecode Model Explorer

Browse **{len(MODELS)} EEG / biosignal model architectures** from
[braindecode](https://braindecode.org). Read the rendered docstring,
configure signal shape, and instantiate any model live to inspect its
parameter count and layer summary.

> No pretrained weights are loaded — this is a pure architecture browser.
> For weights, see [`huggingface.co/braindecode`](https://huggingface.co/braindecode).
"""
def build_app() -> gr.Blocks:
    """Assemble the Gradio interface.

    Left column: architecture dropdown, info card, signal-shape form and
    the "Instantiate model" button. Right column: "Documentation" and
    "Live instance" tabs. Changing the dropdown calls ``show_model``, and
    the button calls ``instantiate``.
    """
    # Land on EEGNetv4 when it is available. Otherwise fall back to the first
    # discovered model (or nothing) instead of raising KeyError at startup.
    if "EEGNetv4" in MODELS:
        initial = "EEGNetv4"
    elif MODEL_NAMES:
        initial = MODEL_NAMES[0]
    else:
        initial = None
    # Seed the form from the same per-family table the dropdown handler uses,
    # rather than repeating the numbers here.
    start = _defaults_for(initial or "")
    initial_info = _info_card(initial) if initial else ""
    initial_doc = render_docstring_html(MODELS[initial].__doc__) if initial else ""

    with gr.Blocks(
        title="Braindecode Model Explorer",
        theme=gr.themes.Soft(primary_hue="blue"),
    ) as app:
        gr.Markdown(INTRO)
        with gr.Row():
            with gr.Column(scale=1):
                model_dd = gr.Dropdown(
                    choices=MODEL_NAMES,
                    value=initial,
                    label="Architecture",
                    interactive=True,
                )
                info_html = gr.HTML(initial_info)
                gr.Markdown("### Signal configuration")
                with gr.Group():
                    n_chans = gr.Number(
                        value=start["n_chans"], label="n_chans", precision=0
                    )
                    sfreq = gr.Number(value=start["sfreq"], label="sfreq (Hz)")
                    window_s = gr.Number(
                        value=start["input_window_seconds"],
                        label="input_window_seconds",
                    )
                    n_outputs = gr.Number(
                        value=start["n_outputs"], label="n_outputs", precision=0
                    )
                run_btn = gr.Button("Instantiate model", variant="primary")
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Documentation"):
                        doc_html = gr.HTML(initial_doc)
                    with gr.TabItem("Live instance"):
                        result_md = gr.Markdown(
                            "_Press **Instantiate model** to build the network._"
                        )
                        summary_md = gr.Markdown()
        # wiring
        model_dd.change(
            show_model,
            inputs=model_dd,
            outputs=[info_html, doc_html, n_chans, sfreq, window_s, n_outputs],
        )
        run_btn.click(
            instantiate,
            inputs=[model_dd, n_chans, sfreq, window_s, n_outputs],
            outputs=[result_md, summary_md],
        )
        gr.Markdown(
            "---\nMade with [braindecode](https://braindecode.org) · "
            "Source: [github.com/braindecode/braindecode]"
            "(https://github.com/braindecode/braindecode)"
        )
    return app


if __name__ == "__main__":
    build_app().launch()