Commit e733120 by TorresYang (verified) · 1 parent: 1050ca4

Create README.md
---
license: mit
library_name: pytorch
tags:
- sparse-autoencoder
- interpretability
- llama-3.2
- qwen-2.5
- mechanistic-interpretability
pipeline_tag: feature-extraction
language:
- en
base_model:
- Qwen/Qwen2.5-0.5B
- meta-llama/Llama-3.2-1B
---

# Model Card for SSAE Checkpoints

This is the official model repository for the paper **"Step-Level Sparse Autoencoder for Reasoning Process Interpretation"**.

This repository contains the trained **Step-Level Sparse Autoencoder (SSAE)** checkpoints.

- **Paper:** [Arxiv Link Here]()
- **Code:** [GitHub Link Here]()
- **Collection:** [HuggingFace]()

## Model Overview

The checkpoints are provided as PyTorch state dictionaries (`.pt` files). Each file represents an SSAE trained on a specific **Base Model** using a specific **Dataset**.

### Naming Convention
The filenames follow this structure:
`{Dataset}_{BaseModel}_{SparsityConfig}.pt`

- **Dataset:** Source data used for training (e.g., `gsm8k`, `numina`, `opencodeinstruct`).
- **Base Model:** The LLM whose activations were encoded (e.g., `Llama3.2-1b`, `Qwen2.5-0.5b`).
- **SparsityConfig:** Target sparsity (e.g., `spar-10` indicates a target sparsity `tau_{spar}` of 10).

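The convention above can be parsed mechanically. The helper below is an illustrative sketch (not part of the released code) that splits a checkpoint filename into its three components:

```python
# Hypothetical helper: parse a checkpoint filename following the
# {Dataset}_{BaseModel}_{SparsityConfig}.pt convention described above.
def parse_ssae_filename(filename: str) -> dict:
    stem = filename.removesuffix(".pt")
    dataset, base_model, sparsity = stem.split("_")
    # "spar-10" -> target sparsity tau_spar = 10
    tau_spar = int(sparsity.split("-")[1])
    return {"dataset": dataset, "base_model": base_model, "tau_spar": tau_spar}

print(parse_ssae_filename("gsm8k-385k_Llama3.2-1b_spar-10.pt"))
```

Note that the dataset and base-model components may themselves contain hyphens, so only underscores separate the three fields.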
## Checkpoints List

Below is the list of available checkpoints in this repository:

| Filename | Base Model | Training Dataset | Description |
| :--- | :--- | :--- | :--- |
| `gsm8k-385k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | GSM8K | SSAE trained on Llama-3.2-1B using GSM8K-385K. |
| `gsm8k-385k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | GSM8K | SSAE trained on Qwen-2.5-0.5B using GSM8K-385K. |
| `numina-859k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | Numina | SSAE trained on Qwen-2.5-0.5B using Numina-859K. |
| `opencodeinstruct-36k_Llama3.2-1b_spar-10.pt` | Llama-3.2-1B | OpenCodeInstruct | SSAE trained on Llama-3.2-1B using OpenCodeInstruct-36K. |
| `opencodeinstruct-36k_Qwen2.5-0.5b_spar-10.pt` | Qwen-2.5-0.5B | OpenCodeInstruct | SSAE trained on Qwen-2.5-0.5B using OpenCodeInstruct-36K. |

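If you want to pick a checkpoint programmatically, the table above can be expressed as a small lookup keyed by base model and dataset. The mapping below is an illustrative sketch (the filenames come from the table; the helper itself is not part of the released code):

```python
# Checkpoint filenames from the table above, keyed by (base model, dataset).
CHECKPOINTS = {
    ("Llama-3.2-1B", "GSM8K"): "gsm8k-385k_Llama3.2-1b_spar-10.pt",
    ("Qwen-2.5-0.5B", "GSM8K"): "gsm8k-385k_Qwen2.5-0.5b_spar-10.pt",
    ("Qwen-2.5-0.5B", "Numina"): "numina-859k_Qwen2.5-0.5b_spar-10.pt",
    ("Llama-3.2-1B", "OpenCodeInstruct"): "opencodeinstruct-36k_Llama3.2-1b_spar-10.pt",
    ("Qwen-2.5-0.5B", "OpenCodeInstruct"): "opencodeinstruct-36k_Qwen2.5-0.5b_spar-10.pt",
}

def checkpoint_for(base_model: str, dataset: str) -> str:
    """Return the checkpoint filename for a given base model and dataset."""
    return CHECKPOINTS[(base_model, dataset)]

print(checkpoint_for("Qwen-2.5-0.5B", "Numina"))
```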
## Usage

The provided `.pt` files contain not only the model weights but also the training configuration and metadata.

Structure of the checkpoint dictionary:
- `model`: The model state dictionary (weights).
- `config`: Configuration dictionary (sparsity factor, etc.).
- `encoder_name` / `decoder_name`: Names of the base models used.
- `global_step`: Training step count.

### Loading Code Example

```python
import torch
from huggingface_hub import hf_hub_download

# 1. Download the checkpoint
repo_id = "Miaow-Lab/SSAE-Models"
filename = "gsm8k-385k_Llama3.2-1b_spar-10.pt"  # Example filename

checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)

# 2. Load the full checkpoint dictionary
# Note: map_location="cpu" is recommended for initial loading
checkpoint = torch.load(checkpoint_path, map_location="cpu")

print(f"Loaded checkpoint (Step: {checkpoint.get('global_step', 'Unknown')})")
print(f"Config: {checkpoint.get('config')}")

# 3. Initialize your model
# Use the metadata from the checkpoint to ensure correct initialization arguments
# model = MyModel(
#     tokenizer=...,
#     sparsity_factor=checkpoint['config'].get('sparsity_factor'),  # Adjust key based on your config structure
#     init_from=(checkpoint['encoder_name'], checkpoint['decoder_name']),
# )

# 4. Load the weights (requires the `model` instance from step 3)
# CRITICAL: The weights are stored under the "model" key
model.load_state_dict(checkpoint["model"], strict=True)

model.to("cuda")  # Move to GPU if needed
model.eval()
```

## Citation

If you use these models or the associated code in your research, please cite our paper:

```bibtex

```