---
license: mit
library_name: pytorch
tags:
- sparse-autoencoder
- interpretability
- llama-3.2
- qwen-2.5
- mechanistic-interpretability
pipeline_tag: feature-extraction
language:
- en
base_model:
- Qwen/Qwen2.5-0.5B
- meta-llama/Llama-3.2-1B
---

# Model Card for SSAE Checkpoints

This is the official model repository for the paper **"Step-Level Sparse Autoencoder for Reasoning Process Interpretation"**.

This repository contains the trained **Step-Level Sparse Autoencoder (SSAE)** checkpoints.

- **Paper:** [Arxiv Link Here]()
- **Code:** [GitHub Link Here]()
- **Collection:** [HuggingFace]()

## Model Overview

The checkpoints are provided as PyTorch state dictionaries (`.pt` files). Each file represents an SSAE trained on a specific **Base Model** using a specific **Dataset**.

### Naming Convention
The filenames follow this structure:
`{Dataset}_{BaseModel}_{SparsityConfig}.pt`

- **Dataset:** Source data used for training (e.g., `gsm8k`, `numina`, `opencodeinstruct`).
- **Base Model:** The LLM whose activations were encoded (e.g., `Llama3.2-1b`, `Qwen2.5-0.5b`).
- **Sparsity Config:** Target sparsity (e.g., `spar-10` indicates a target sparsity of `tau_{spar} = 10`).
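
As an illustration, the naming convention can be parsed with a small helper (a sketch only; `parse_checkpoint_name` is not part of the released code):

```python
def parse_checkpoint_name(filename: str) -> dict:
    """Split an SSAE checkpoint filename into its three components."""
    stem = filename.removesuffix(".pt")
    dataset, base_model, sparsity = stem.split("_")
    return {
        "dataset": dataset,        # e.g. "gsm8k-385k"
        "base_model": base_model,  # e.g. "Llama3.2-1b"
        "tau_spar": int(sparsity.removeprefix("spar-")),  # e.g. 10
    }

info = parse_checkpoint_name("gsm8k-385k_Llama3.2-1b_spar-10.pt")
print(info)  # {'dataset': 'gsm8k-385k', 'base_model': 'Llama3.2-1b', 'tau_spar': 10}
```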

## Checkpoints List

Below is the list of available checkpoints in this repository:

| Filename | Base Model | Training Dataset | Description |
| :--- | :--- | :--- | :--- |
| `gsm8k-385k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | GSM8K | SSAE trained on Llama-3.2-1B using GSM8K-385K. |
| `gsm8k-385k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | GSM8K | SSAE trained on Qwen-2.5-0.5B using GSM8K-385K. |
| `numina-859k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | Numina | SSAE trained on Qwen-2.5-0.5B using Numina-859K. |
| `opencodeinstruct-36k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | OpenCodeInstruct | SSAE trained on Llama-3.2-1B using OpenCodeInstruct-36K. |
| `opencodeinstruct-36k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | OpenCodeInstruct | SSAE trained on Qwen-2.5-0.5B using OpenCodeInstruct-36K. |

## Usage

The provided `.pt` files contain not only the model weights but also the training configuration and metadata.

Structure of the checkpoint dictionary:
- `model`: The model state dictionary (weights).
- `config`: Configuration dictionary (sparsity factor, etc.).
- `encoder_name` / `decoder_name`: Names of the base models used.
- `global_step`: Training step count.

### Loading Code Example

```python
import torch
from huggingface_hub import hf_hub_download

# 1. Download the checkpoint
repo_id = "Miaow-Lab/SSAE-Models"
filename = "gsm8k-385k_Llama3.2-1b_spar-10.pt"  # Example filename

checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

# 2. Load the full checkpoint dictionary
# Note: map_location="cpu" is recommended for initial loading.
# weights_only=False is required on recent PyTorch versions (>= 2.6),
# since the checkpoint stores config dictionaries and metadata, not just tensors.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

print(f"Loaded checkpoint (Step: {checkpoint.get('global_step', 'Unknown')})")
print(f"Config: {checkpoint.get('config')}")

# 3. Initialize your model
# Use the metadata from the checkpoint to ensure correct initialization arguments
# model = MyModel(
#     tokenizer=...,
#     sparsity_factor=checkpoint['config'].get('sparsity_factor'),  # Adjust key based on your config structure
#     init_from=(checkpoint['encoder_name'], checkpoint['decoder_name'])
# )

# 4. Load the weights into the initialized model
# CRITICAL: The weights are stored under the "model" key
# model.load_state_dict(checkpoint["model"], strict=True)
# model.to("cuda")  # Move to GPU if needed
# model.eval()
```

## Citation
If you use these models or the associated code in your research, please cite our paper:
```bibtex

```