---
license: mit
library_name: pytorch
tags:
- sparse-autoencoder
- interpretability
- llama-3.2
- qwen-2.5
- mechanistic-interpretability
pipeline_tag: feature-extraction
language:
- en
base_model:
- Qwen/Qwen2.5-0.5B
- meta-llama/Llama-3.2-1B
---

# Model Card for SSAE Checkpoints

This is the official model repository for the paper **"Step-Level Sparse Autoencoder for Reasoning Process Interpretation"**. It contains the trained **Step-Level Sparse Autoencoder (SSAE)** checkpoints.

- **Paper:** [Arxiv Link Here]()
- **Code:** [GitHub Link Here]()
- **Collection:** [HuggingFace]()

## Model Overview

The checkpoints are provided as PyTorch state dictionaries (`.pt` files). Each file represents an SSAE trained on a specific **Base Model** using a specific **Dataset**.

### Naming Convention

The filenames follow this structure: `{Dataset}_{BaseModel}_{SparsityConfig}.pt`

- **Dataset:** Source data used for training (e.g., `gsm8k`, `numina`, `opencodeinstruct`).
- **Base Model:** The LLM whose activations were encoded (e.g., `Llama3.2-1b`, `Qwen2.5-0.5b`).
- **SparsityConfig:** Target sparsity (e.g., `spar-10` indicates a target sparsity `tau_{spar}` of 10).

## Checkpoints List

Below is the list of available checkpoints in this repository:

| Filename | Base Model | Training Dataset | Description |
| :--- | :--- | :--- | :--- |
| `gsm8k-385k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | GSM8K | SSAE trained on Llama-3.2-1B using GSM8K-385K. |
| `gsm8k-385k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | GSM8K | SSAE trained on Qwen-2.5-0.5B using GSM8K-385K. |
| `numina-859k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | Numina | SSAE trained on Qwen-2.5-0.5B using Numina-859K. |
| `opencodeinstruct-36k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | OpenCodeInstruct | SSAE trained on Llama-3.2-1B using OpenCodeInstruct-36K. |
| `opencodeinstruct-36k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | OpenCodeInstruct | SSAE trained on Qwen-2.5-0.5B using OpenCodeInstruct-36K. |
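The naming convention can also be handled programmatically when iterating over checkpoints. A minimal sketch follows; the `parse_ssae_filename` helper is hypothetical and not part of the released code.

```python
# Hypothetical helper: split an SSAE checkpoint filename into its components,
# following the {Dataset}_{BaseModel}_{SparsityConfig}.pt convention above.
def parse_ssae_filename(filename: str) -> dict:
    stem = filename.removesuffix(".pt")
    dataset, base_model, sparsity = stem.split("_")
    return {
        "dataset": dataset,                        # e.g. "gsm8k-385k"
        "base_model": base_model,                  # e.g. "Llama3.2-1b"
        "tau_spar": int(sparsity.split("-")[1]),   # e.g. 10 from "spar-10"
    }

print(parse_ssae_filename("gsm8k-385k_Llama3.2-1b_spar-10.pt"))
# {'dataset': 'gsm8k-385k', 'base_model': 'Llama3.2-1b', 'tau_spar': 10}
```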
## Usage

The provided `.pt` files contain not only the model weights but also the training configuration and metadata.

Structure of the checkpoint dictionary:

- `model`: The model state dictionary (weights).
- `config`: Configuration dictionary (sparsity factor, etc.).
- `encoder_name` / `decoder_name`: Names of the base models used.
- `global_step`: Training step count.

### Loading Code Example

```python
import torch
from huggingface_hub import hf_hub_download

# 1. Download the checkpoint
repo_id = "Miaow-Lab/SSAE-Models"
filename = "gsm8k-385k_Llama3.2-1b_spar-10.pt"  # Example filename
checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

# 2. Load the full checkpoint dictionary
# Note: map_location="cpu" is recommended for initial loading
checkpoint = torch.load(checkpoint_path, map_location="cpu")

print(f"Loaded checkpoint (Step: {checkpoint.get('global_step', 'Unknown')})")
print(f"Config: {checkpoint.get('config')}")

# 3. Initialize your model
# Use the metadata from the checkpoint to ensure correct initialization arguments.
# Replace `MyModel` with the SSAE class from the accompanying code repository:
# model = MyModel(
#     tokenizer=...,
#     sparsity_factor=checkpoint['config'].get('sparsity_factor'),  # Adjust key based on your config structure
#     init_from=(checkpoint['encoder_name'], checkpoint['decoder_name'])
# )

# 4. Load the weights
# CRITICAL: The weights are stored under the "model" key
model.load_state_dict(checkpoint["model"], strict=True)
model.to("cuda")  # Move to GPU if needed
model.eval()
```

## Citation

If you use these models or the associated code in your research, please cite our paper:

```bibtex
```