---
license: mit
library_name: pytorch
tags:
- sparse-autoencoder
- interpretability
- llama-3.2
- qwen-2.5
- mechanistic-interpretability
pipeline_tag: feature-extraction
language:
- en
base_model:
- Qwen/Qwen2.5-0.5B
- meta-llama/Llama-3.2-1B
---
# Model Card for SSAE Checkpoints
This is the official model repository for the paper **"Step-Level Sparse Autoencoder for Reasoning Process Interpretation"**.
This repository contains the trained **Step-Level Sparse Autoencoder (SSAE)** checkpoints.
- **Paper:** [Arxiv Link Here]()
- **Code:** [GitHub Link Here]()
- **Collection:** [HuggingFace]()
## Model Overview
The checkpoints are provided as PyTorch state dictionaries (`.pt` files). Each file represents an SSAE trained on a specific **Base Model** using a specific **Dataset**.
### Naming Convention
The filenames follow this structure:
`{Dataset}_{BaseModel}_{SparsityConfig}.pt`
- **Dataset:** Source data used for training (e.g., `gsm8k`, `numina`, `opencodeinstruct`).
- **Base Model:** The LLM whose activations were encoded (e.g., `Llama3.2-1b`, `Qwen2.5-0.5b`).
- **SparsityConfig:** Target sparsity (e.g., `spar-10` indicates a target sparsity of `tau_{spar} = 10`).
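As an illustration, the naming convention can be parsed programmatically. The helper below is a sketch (not part of the released code) and assumes exactly three underscore-separated parts:

```python
def parse_ssae_filename(filename: str) -> dict:
    """Split an SSAE checkpoint filename into its documented components.

    Illustrative sketch only; assumes the `{Dataset}_{BaseModel}_{SparsityConfig}.pt`
    convention with exactly three underscore-separated parts.
    """
    stem = filename.removesuffix(".pt")
    dataset, base_model, sparsity = stem.split("_")
    # "spar-10" encodes target sparsity tau_spar = 10
    tau_spar = int(sparsity.split("-")[1])
    return {"dataset": dataset, "base_model": base_model, "tau_spar": tau_spar}

info = parse_ssae_filename("gsm8k-385k_Llama3.2-1b_spar-10.pt")
print(info)  # {'dataset': 'gsm8k-385k', 'base_model': 'Llama3.2-1b', 'tau_spar': 10}
```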
## Checkpoints List
Below is the list of available checkpoints in this repository:
| Filename | Base Model | Training Dataset | Description |
| :--- | :--- | :--- | :--- |
| `gsm8k-385k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | GSM8K | SSAE trained on Llama-3.2-1B using GSM8K-385K. |
| `gsm8k-385k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | GSM8K | SSAE trained on Qwen-2.5-0.5B using GSM8K-385K. |
| `numina-859k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | Numina | SSAE trained on Qwen-2.5-0.5B using Numina-859K. |
| `opencodeinstruct-36k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | OpenCodeInstruct | SSAE trained on Llama-3.2-1B using OpenCodeInstruct-36K. |
| `opencodeinstruct-36k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | OpenCodeInstruct | SSAE trained on Qwen-2.5-0.5B using OpenCodeInstruct-36K. |
## Usage
The provided `.pt` files contain not only the model weights but also the training configuration and metadata.
Structure of the checkpoint dictionary:
- `model`: The model state dictionary (weights).
- `config`: Configuration dictionary (sparsity factor, etc.).
- `encoder_name` / `decoder_name`: Names of the base models used.
- `global_step`: Training step count.
### Loading Code Example
```python
import torch
from huggingface_hub import hf_hub_download

# 1. Download the checkpoint
repo_id = "Miaow-Lab/SSAE-Models"
filename = "gsm8k-385k_Llama3.2-1b_spar-10.pt"  # Example filename
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

# 2. Load the full checkpoint dictionary
# Note: map_location="cpu" is recommended for initial loading
checkpoint = torch.load(checkpoint_path, map_location="cpu")
print(f"Loaded checkpoint (Step: {checkpoint.get('global_step', 'Unknown')})")
print(f"Config: {checkpoint.get('config')}")

# 3. Initialize your model
# Use the metadata from the checkpoint to ensure correct initialization arguments
# model = MyModel(
#     tokenizer=...,
#     sparsity_factor=checkpoint['config'].get('sparsity_factor'),  # Adjust key based on your config structure
#     init_from=(checkpoint['encoder_name'], checkpoint['decoder_name']),
# )

# 4. Load the weights (uncomment once `model` is instantiated above)
# CRITICAL: the weights are stored under the "model" key
# model.load_state_dict(checkpoint["model"], strict=True)
# model.to("cuda")  # Move to GPU if needed
# model.eval()
```
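Beyond loading weights, the checkpoint dictionary can be inspected directly, for example to recover parameter shapes before writing the model class. A minimal sketch (the key names mirror the structure documented above; the tensors and shapes here are stand-ins, not real SSAE dimensions):

```python
import torch

# Stand-in checkpoint with the documented keys; a real one comes from
# torch.load(checkpoint_path, map_location="cpu") as shown above.
checkpoint = {
    "model": {
        "encoder.weight": torch.zeros(8192, 2048),  # stand-in shapes
        "decoder.weight": torch.zeros(2048, 8192),
    },
    "config": {"sparsity_factor": 10},
    "encoder_name": "meta-llama/Llama-3.2-1B",
    "decoder_name": "meta-llama/Llama-3.2-1B",
    "global_step": 10000,
}

# Map every parameter name under the "model" key to its shape.
shapes = {name: tuple(t.shape) for name, t in checkpoint["model"].items()}
print(shapes)
# {'encoder.weight': (8192, 2048), 'decoder.weight': (2048, 8192)}
```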
## Citation
If you use these models or the associated code in your research, please cite our paper:
```bibtex
```