# TBX5 Sparse Autoencoder (k=128)
## Model Description

This is a Sparse Autoencoder (SAE) trained on Evo2 embeddings to detect TBX5 transcription factor binding patterns in cardiac genes.
## Architecture
- Base Model: Evo2-7B Layer 26 embeddings
- SAE Type: BatchTopKTiedSAE
- Input Dimension: 4,096 (Evo2 embeddings)
- Hidden Dimension: 32,768 (8x expansion)
- Sparsity: k=128 (99.61% sparse)
- Active Features: 128 per position (twice as many as the k=64 variant)
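The class name suggests a tied-weights SAE with a top-k activation. A minimal sketch of that design, with the interface assumed from the usage example below (the actual implementation may differ, e.g. BatchTopK variants select the top activations across a whole batch during training; this sketch keeps a simpler per-position top-k):

```python
import torch
import torch.nn as nn

class BatchTopKTiedSAE(nn.Module):
    """Sketch of a tied-weights top-k SAE (interface assumed, not the
    repo's actual code). The decoder reuses the transpose of the encoder
    matrix, and only the k largest pre-activations per position survive."""

    def __init__(self, d_in=4096, d_hidden=32768, k=128,
                 device='cpu', dtype=torch.float32):
        super().__init__()
        self.k = k
        # Tied weights: one matrix W serves the encoder, W^T the decoder.
        self.W = nn.Parameter(
            torch.randn(d_in, d_hidden, device=device, dtype=dtype) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, device=device, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_in, device=device, dtype=dtype))

    def encode(self, x):
        # Pre-activations, then keep only the top-k per position.
        pre = (x - self.b_dec) @ self.W + self.b_enc
        topk = torch.topk(pre, self.k, dim=-1)
        z = torch.zeros_like(pre)
        z.scatter_(-1, topk.indices, torch.relu(topk.values))
        return z

    def decode(self, z):
        return z @ self.W.T + self.b_dec

    def forward(self, x):
        return self.decode(self.encode(x))
```

With d_hidden=32,768 and k=128, each position's code has at most 128 nonzero entries, which is where the 99.61% sparsity figure comes from.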
## Training Details
- Training Data: 1,204,224 positions from 50 cardiac genes
- Loss Function: MSE
- Final Loss: 0.043025
- Epochs: 3
- Device: CUDA
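The configuration above could be reproduced with a loop along these lines; `loader`, the optimizer, and the learning rate are assumptions for illustration — only the MSE objective and the 3 epochs are stated in this card:

```python
import torch

def train_sae(model, loader, epochs=3, lr=1e-4):
    """Hypothetical training loop matching the reported setup (MSE loss,
    3 epochs). `loader` yields batches of Evo2 layer-26 embeddings."""
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        total, n = 0.0, 0
        for batch in loader:  # batch shape: (B, 4096)
            recon = model(batch)
            loss = torch.nn.functional.mse_loss(recon, batch)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item() * batch.shape[0]
            n += batch.shape[0]
        print(f"epoch {epoch}: mse={total / n:.6f}")
    return model
```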
## Key Features

### Top Active Features (by activation rate)
- Feature 2241: 96.7% activation rate
- Feature 17017: 95.7% activation rate
- Feature 18612: 94.7% activation rate
- Feature 5366: 92.6% activation rate
- Feature 29391: 91.2% activation rate
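An activation rate as reported above is simply the fraction of positions at which a feature is nonzero after the top-k selection. A sketch of that computation (`codes` stands in for the output of `model.encode(embeddings)`):

```python
import torch

def activation_rates(codes: torch.Tensor) -> torch.Tensor:
    """Fraction of positions at which each feature fires (is nonzero).
    codes: (n_positions, d_hidden) SAE feature activations."""
    return (codes != 0).float().mean(dim=0)  # shape: (d_hidden,)

# Sparse stand-in data for illustration; use real SAE codes in practice.
codes = torch.relu(torch.randn(1000, 64) - 1.0)
rates = activation_rates(codes)
top = torch.topk(rates, 3)
for idx, rate in zip(top.indices.tolist(), top.values.tolist()):
    print(f"Feature {idx}: {rate:.1%} activation rate")
```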
## Comparison with k=64
| Model | Sparsity | Active/Position | Use Case |
|---|---|---|---|
| k=64 | 99.80% | 64 features | Clear motif identification |
| k=128 | 99.61% | 128 features | Complex pattern detection |
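The sparsity column follows directly from k and the hidden dimension — it is the fraction of the 32,768 features that stay inactive at any one position:

```python
# Sparsity = 1 - k / d_hidden, i.e. the fraction of inactive features
# per position.
d_hidden = 32768
for k in (64, 128):
    print(f"k={k}: {1 - k / d_hidden:.2%} sparse")
# k=64: 99.80% sparse
# k=128: 99.61% sparse
```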
## Usage

```python
import torch

# Load the checkpoint
checkpoint = torch.load('sae_model_layer26_topk128.pt', map_location='cpu',
                        weights_only=False)

# Create the model; BatchTopKTiedSAE must be defined or imported from this
# repo's code first.
model = BatchTopKTiedSAE(
    d_in=4096,
    d_hidden=32768,
    k=128,
    device='cuda',
    dtype=torch.float32,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Use with Evo2 layer-26 embeddings
# embeddings shape: (batch, 4096)
features = model.encode(embeddings)  # (batch, 32768), at most 128 nonzero per row
```
## Files

- `sae_model_layer26_topk128.pt` - PyTorch model checkpoint (512.1 MB)
- `model_config.json` - Model configuration
- `training_summary.json` - Training statistics
## Citation

If you use this model, please cite:

```bibtex
@misc{tbx5_sae_k128_2024,
  title={TBX5 Sparse Autoencoder k=128},
  author={Anonymous},
  year={2024},
  publisher={HuggingFace}
}
```
## License
MIT License