---
license: apache-2.0
tags:
- sparse-autoencoder
- mechanistic-interpretability
- tool-calling
- gemma
- ministral
- qwen
arxiv: 2605.18882
---

# toolcalling-sae

TopK sparse autoencoder (SAE) checkpoints from [To Call or Not to Call: Diagnosing Intrinsic Over-Calling Bias in LLM Agents](https://arxiv.org/abs/2605.18882).

## Checkpoints

| Model | Layer | Dict size | k (active latents per token) | Stage 1 tokens | Stage 2 tokens |
|-------|-------|-----------|------------------------------|----------------|----------------|
| gemma-3-1b-it | L17 | 9,216 | 128 | 50M | 5M |
| gemma-3-4b-it | L29 | 20,480 | 128 | 50M | 5M |
| gemma-4-E2B-it | L30 | 12,288 | 128 | 50M | 5M |
| gemma-4-E4B-it | L30 | 20,480 | 128 | 50M | 5M |
| Ministral-3-3B-Instruct-2512 | L21 | 24,576 | 128 | 50M | 5M |
| Ministral-3-8B-Instruct-2512 | L31 | 32,768 | 128 | 50M | 5M |
| Qwen3.5-4B | L25 | 20,480 | 128 | 50M | 5M |
| Qwen3.5-9B | L25 | 32,768 | 128 | 50M | 5M |

- **Stage 1**: pre-trained on model activations collected from [OpenWebText2](https://openwebtext2.readthedocs.io/) text.
- **Stage 2**: fine-tuned on tool-calling activations from the [When2Call](https://arxiv.org/abs/2605.18882) benchmark.

All checkpoints use `bfloat16` precision.

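
To make the *Dict size* and *k* columns concrete, here is a minimal NumPy sketch of one common TopK SAE formulation. An activation vector is projected onto every dictionary latent, only the k largest pre-activations are kept, and that sparse code is decoded back to the model's hidden size. This is an illustration with random weights, not the repo's `TopKSAE`. The function name `topk_sae_forward`, the parameter names, and the choice to subtract `b_dec` before encoding are assumptions, and the actual implementation may differ.

```python
import numpy as np


def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Hypothetical TopK SAE forward pass (illustration only, not the repo's TopKSAE).

    x: (n, d_model) activations; W_enc: (d_model, d_sae); W_dec: (d_sae, d_model).
    Returns the sparse code z (n, d_sae) and the reconstruction x_hat (n, d_model).
    """
    # Pre-activations for every dictionary latent (assumes inputs are centered by b_dec).
    pre = (x - b_dec) @ W_enc + b_enc
    # Column indices of the k largest pre-activations in each row (unordered).
    idx = np.argpartition(pre, -k, axis=-1)[:, -k:]
    # Keep those k values (clipped at zero, as a ReLU would) and zero everything else.
    kept = np.maximum(np.take_along_axis(pre, idx, axis=-1), 0.0)
    z = np.zeros_like(pre)
    np.put_along_axis(z, idx, kept, axis=-1)
    # Decode the sparse code back to the model's hidden size.
    x_hat = z @ W_dec + b_dec
    return z, x_hat


# Toy sizes with random weights; the released checkpoints use the sizes in the table (e.g. d_sae=9216, k=128).
rng = np.random.default_rng(0)
d_model, d_sae, k = 64, 512, 8
W_enc = rng.standard_normal((d_model, d_sae)) / np.sqrt(d_model)
W_dec = W_enc.T.copy()
b_enc, b_dec = np.zeros(d_sae), np.zeros(d_model)
x = rng.standard_normal((4, d_model))

z, x_hat = topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k)
print("nonzero latents per row:", (z != 0).sum(axis=-1))  # at most k per row
print("reconstruction shape:", x_hat.shape)
```

The top-k selection is what makes the code sparse. No L1 penalty is involved: each token activates at most k of the dictionary's latents.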
## Usage

```python
from huggingface_hub import hf_hub_download

# sae_model.py ships with this repo; it must be importable,
# e.g. saved next to this script or added to your Python path.
from sae_model import TopKSAE

# Stage 2 checkpoint for gemma-3-1b-it, layer 17, 9,216 latents.
ckpt_path = hf_hub_download(
    repo_id="SKwra/toolcalling-sae",
    filename="gemma-3-1b-it/stage2/gemma-3-1b-it-L17-d9216-5M-stage2.pt",
)
sae = TopKSAE.load(ckpt_path, device="cuda")
```

`sae_model.py`, which defines `TopKSAE`, is included in this repo. The full code is available on [GitHub](https://github.com/SKURA502/agent-sae).

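
The checkpoint path in the loading example appears to encode the model name, layer, dictionary size, Stage 2 token budget, and stage. The helper below parses that layout so a downloaded file can be checked against the Checkpoints table. Both the name `parse_ckpt_name` and the regular expression are hypothetical and were inferred from that single example path. Check the actual file names in the repo before relying on this helper.

```python
import re

# Hypothetical pattern inferred from one example path; other checkpoints may be named differently.
CKPT_RE = re.compile(
    r"^(?P<model>.+)-L(?P<layer>\d+)-d(?P<dict_size>\d+)"
    r"-(?P<tokens>\d+[KMB])-(?P<stage>stage\d+)\.pt$"
)


def parse_ckpt_name(path: str) -> dict:
    """Split a checkpoint file name into model, layer, dict size, token budget, and stage."""
    name = path.rsplit("/", 1)[-1]  # drop the "<model>/<stage>/" directories
    match = CKPT_RE.match(name)
    if match is None:
        raise ValueError(f"unrecognized checkpoint name: {name!r}")
    info = match.groupdict()
    info["layer"] = int(info["layer"])
    info["dict_size"] = int(info["dict_size"])
    return info


info = parse_ckpt_name("gemma-3-1b-it/stage2/gemma-3-1b-it-L17-d9216-5M-stage2.pt")
print(info)
```

Here the parsed `layer` and `dict_size` (17 and 9216) match the gemma-3-1b-it row of the Checkpoints table.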