---
license: mit
tags:
- pytorch
- addressed-state-attention
- interpretable-ai
- mechanistic-interpretability
language:
- en
---

# Addressed State Attention (ASA)

An interpretable, slot-based attention mechanism that achieves competitive language-modeling performance.

## Quick Start

Install from GitHub, then download the checkpoint and generate:

```bash
pip install git+https://github.com/DigitalDaimyo/AddressedStateAttention.git
```

```python
from asa import load_asm_checkpoint, generate
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hugging Face Hub
ckpt_path = hf_hub_download(
    repo_id="DigitalDaimyo/AddressedStateAttention",
    filename="checkpoints/fineweb_187M_75k.pt",
)

# Load the checkpoint in analysis mode
model, cfg, ckpt = load_asm_checkpoint(ckpt_path, mode="analysis")

# ASA pairs with the GPT-2 tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Generate text from a prompt
print(generate(model, tokenizer, "Once upon a time"))
```
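
`hf_hub_download` caches the file locally, so re-running the snippet reuses the downloaded checkpoint. The loaded model can also be reused across prompts; a small usage sketch with the same `generate` call as above (the prompts are arbitrary examples):

```python
# Reuse the already-loaded model for several prompts
for prompt in ["Once upon a time", "The experiment began when"]:
    print(generate(model, tokenizer, prompt))
```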

## Performance

FineWeb, 187M parameters: validation loss 3.73 / perplexity 41.6 (75k training steps, batch size 32, sequence length 1024).
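
As a quick sanity check, perplexity is the exponential of the cross-entropy loss (assuming, as is standard, that the loss is reported in nats), and the two numbers above agree up to rounding:

```python
import math

# Perplexity = exp(cross-entropy loss in nats)
print(math.exp(3.73))  # ~41.7, consistent with the reported 41.6 PPL
```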

Architecture: 21 layers, model dimension 768, 12 attention heads, 16 slots.
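
For intuition about the slot-based design, here is a deliberately generic toy sketch in which each token softmax-addresses a small set of learned slot states. This illustrates the general idea only, not the repository's actual ASA implementation; every name and detail in it is hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySlotAttention(nn.Module):
    """Toy sketch: tokens attend over a fixed set of learned slot states."""

    def __init__(self, d_model: int = 768, n_heads: int = 12, n_slots: int = 16):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        # The addressable states: a small, fixed slot axis is what makes
        # the attention weights easy to inspect.
        self.slots = nn.Parameter(torch.randn(n_slots, d_model) * 0.02)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, seq, d_model)
        B, T, _ = x.shape
        S = self.slots.shape[0]
        q = self.q_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(self.slots).view(S, self.n_heads, self.d_head).permute(1, 0, 2)
        v = self.v_proj(self.slots).view(S, self.n_heads, self.d_head).permute(1, 0, 2)
        # (B, n_heads, T, S): each token distributes attention over the slots
        attn = F.softmax(q @ k.transpose(-2, -1) / self.d_head**0.5, dim=-1)
        y = attn @ v  # (B, n_heads, T, d_head)
        return self.o_proj(y.transpose(1, 2).reshape(B, T, -1))

layer = ToySlotAttention()
out = layer(torch.randn(2, 8, 768))  # -> torch.Size([2, 8, 768])
```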

## Links

- Code: https://github.com/DigitalDaimyo/AddressedStateAttention
- Paper: https://github.com/DigitalDaimyo/AddressedStateAttention/paper_drafts